mirror of
https://github.com/apache/lucene.git
synced 2025-02-28 13:29:26 +00:00
LUCENE-10577: enable quantization of HNSW vectors to 8 bits (#1054)
* LUCENE-10577: enable supplying, storing, and comparing HNSW vectors with 8 bit precision
This commit is contained in:
parent
59a0917e25
commit
a693fe819b
@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfos;
|
|||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
import org.apache.lucene.index.SegmentInfo;
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataOutput;
|
import org.apache.lucene.store.DataOutput;
|
||||||
@ -214,6 +215,7 @@ public final class Lucene60FieldInfosFormat extends FieldInfosFormat {
|
|||||||
pointIndexDimensionCount,
|
pointIndexDimensionCount,
|
||||||
pointNumBytes,
|
pointNumBytes,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
isSoftDeletesField);
|
isSoftDeletesField);
|
||||||
} catch (IllegalStateException e) {
|
} catch (IllegalStateException e) {
|
||||||
|
@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
|
|||||||
import org.apache.lucene.codecs.TermVectorsFormat;
|
import org.apache.lucene.codecs.TermVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
|
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
||||||
|
@ -14,7 +14,7 @@
|
|||||||
* See the License for the specific language governing permissions and
|
* See the License for the specific language governing permissions and
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
package org.apache.lucene.codecs.lucene90;
|
package org.apache.lucene.backward_codecs.lucene90;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
@ -29,6 +29,7 @@ import org.apache.lucene.index.FieldInfos;
|
|||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
import org.apache.lucene.index.SegmentInfo;
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.DataOutput;
|
import org.apache.lucene.store.DataOutput;
|
||||||
@ -191,6 +192,7 @@ public final class Lucene90FieldInfosFormat extends FieldInfosFormat {
|
|||||||
pointIndexDimensionCount,
|
pointIndexDimensionCount,
|
||||||
pointNumBytes,
|
pointNumBytes,
|
||||||
vectorDimension,
|
vectorDimension,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
vectorDistFunc,
|
vectorDistFunc,
|
||||||
isSoftDeletesField);
|
isSoftDeletesField);
|
||||||
infos[i].checkConsistency();
|
infos[i].checkConsistency();
|
@ -18,6 +18,7 @@
|
|||||||
package org.apache.lucene.backward_codecs.lucene90;
|
package org.apache.lucene.backward_codecs.lucene90;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
@ -36,6 +37,7 @@ import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
|||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -277,6 +279,21 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||||||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
// The field does not exist or does not index vectors
|
||||||
|
return EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||||
|
VectorValues vectorValues = getVectorValues(field);
|
||||||
|
|
||||||
|
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||||
|
}
|
||||||
|
|
||||||
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
package org.apache.lucene.backward_codecs.lucene91;
|
package org.apache.lucene.backward_codecs.lucene91;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
import org.apache.lucene.codecs.CompoundFormat;
|
import org.apache.lucene.codecs.CompoundFormat;
|
||||||
import org.apache.lucene.codecs.DocValuesFormat;
|
import org.apache.lucene.codecs.DocValuesFormat;
|
||||||
@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
|
|||||||
import org.apache.lucene.codecs.TermVectorsFormat;
|
import org.apache.lucene.codecs.TermVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
|
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
||||||
|
@ -25,6 +25,7 @@ import java.util.Objects;
|
|||||||
import java.util.SplittableRandom;
|
import java.util.SplittableRandom;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.InfoStream;
|
import org.apache.lucene.util.InfoStream;
|
||||||
@ -55,7 +56,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues vectorValues;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final Lucene91BoundsChecker bound;
|
private final Lucene91BoundsChecker bound;
|
||||||
private final HnswGraphSearcher graphSearcher;
|
private final HnswGraphSearcher<float[]> graphSearcher;
|
||||||
|
|
||||||
final Lucene91OnHeapHnswGraph hnsw;
|
final Lucene91OnHeapHnswGraph hnsw;
|
||||||
|
|
||||||
@ -101,7 +102,8 @@ public final class Lucene91HnswGraphBuilder {
|
|||||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||||
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
this.hnsw = new Lucene91OnHeapHnswGraph(maxConn, levelOfFirstNode);
|
||||||
this.graphSearcher =
|
this.graphSearcher =
|
||||||
new HnswGraphSearcher(
|
new HnswGraphSearcher<>(
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
new NeighborQueue(beamWidth, true),
|
new NeighborQueue(beamWidth, true),
|
||||||
new FixedBitSet(vectorValues.size()));
|
new FixedBitSet(vectorValues.size()));
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
package org.apache.lucene.backward_codecs.lucene91;
|
package org.apache.lucene.backward_codecs.lucene91;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
@ -34,8 +35,10 @@ import org.apache.lucene.index.IndexFileNames;
|
|||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -244,6 +247,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||||||
target,
|
target,
|
||||||
k,
|
k,
|
||||||
vectorValues,
|
vectorValues,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
getGraph(fieldEntry),
|
getGraph(fieldEntry),
|
||||||
getAcceptOrds(acceptDocs, fieldEntry),
|
getAcceptOrds(acceptDocs, fieldEntry),
|
||||||
@ -265,6 +269,21 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||||||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
// The field does not exist or does not index vectors
|
||||||
|
return EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||||
|
VectorValues vectorValues = getVectorValues(field);
|
||||||
|
|
||||||
|
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||||
|
}
|
||||||
|
|
||||||
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
private OffHeapVectorValues getOffHeapVectorValues(FieldEntry fieldEntry) throws IOException {
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
@ -144,8 +144,8 @@
|
|||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
||||||
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
||||||
* information about how the segment is sorted
|
* information about how the segment is sorted
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
|
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
|
||||||
* contains metadata about the set of named fields used in the index.
|
* This contains metadata about the set of named fields used in the index.
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
||||||
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
||||||
* field names. These are used to store auxiliary information about the document, such as its
|
* field names. These are used to store auxiliary information about the document, such as its
|
||||||
@ -240,7 +240,7 @@
|
|||||||
* systems that frequently run out of file handles.</td>
|
* systems that frequently run out of file handles.</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
* <tr>
|
* <tr>
|
||||||
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
||||||
* <td>.fnm</td>
|
* <td>.fnm</td>
|
||||||
* <td>Stores information about the fields</td>
|
* <td>Stores information about the fields</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
package org.apache.lucene.backward_codecs.lucene92;
|
package org.apache.lucene.backward_codecs.lucene92;
|
||||||
|
|
||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
|
import org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
import org.apache.lucene.codecs.CompoundFormat;
|
import org.apache.lucene.codecs.CompoundFormat;
|
||||||
import org.apache.lucene.codecs.DocValuesFormat;
|
import org.apache.lucene.codecs.DocValuesFormat;
|
||||||
@ -32,7 +33,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
|
|||||||
import org.apache.lucene.codecs.TermVectorsFormat;
|
import org.apache.lucene.codecs.TermVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
|
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
package org.apache.lucene.backward_codecs.lucene92;
|
package org.apache.lucene.backward_codecs.lucene92;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
@ -30,8 +31,10 @@ import org.apache.lucene.index.FieldInfo;
|
|||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -237,6 +240,7 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||||||
target,
|
target,
|
||||||
k,
|
k,
|
||||||
vectorValues,
|
vectorValues,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
getGraph(fieldEntry),
|
getGraph(fieldEntry),
|
||||||
vectorValues.getAcceptOrds(acceptDocs),
|
vectorValues.getAcceptOrds(acceptDocs),
|
||||||
@ -258,6 +262,21 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||||||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
// The field does not exist or does not index vectors
|
||||||
|
return EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||||
|
VectorValues vectorValues = getVectorValues(field);
|
||||||
|
|
||||||
|
return exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||||
|
}
|
||||||
|
|
||||||
/** Get knn graph values; used for testing */
|
/** Get knn graph values; used for testing */
|
||||||
public HnswGraph getGraph(String field) throws IOException {
|
public HnswGraph getGraph(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
|
@ -144,8 +144,8 @@
|
|||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
||||||
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
||||||
* information about how the segment is sorted
|
* information about how the segment is sorted
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
|
* <li>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Field names}.
|
||||||
* contains metadata about the set of named fields used in the index.
|
* This contains metadata about the set of named fields used in the index.
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
||||||
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
||||||
* field names. These are used to store auxiliary information about the document, such as its
|
* field names. These are used to store auxiliary information about the document, such as its
|
||||||
@ -240,7 +240,7 @@
|
|||||||
* systems that frequently run out of file handles.</td>
|
* systems that frequently run out of file handles.</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
* <tr>
|
* <tr>
|
||||||
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
* <td>{@link org.apache.lucene.backward_codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
||||||
* <td>.fnm</td>
|
* <td>.fnm</td>
|
||||||
* <td>Stores information about the fields</td>
|
* <td>Stores information about the fields</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
|
@ -31,6 +31,7 @@ import org.apache.lucene.index.FieldInfo;
|
|||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
@ -148,7 +149,8 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||||||
OnHeapHnswGraph graph =
|
OnHeapHnswGraph graph =
|
||||||
offHeapVectors.size() == 0
|
offHeapVectors.size() == 0
|
||||||
? null
|
? null
|
||||||
: writeGraph(offHeapVectors, fieldInfo.getVectorSimilarityFunction());
|
: writeGraph(
|
||||||
|
offHeapVectors, VectorEncoding.FLOAT32, fieldInfo.getVectorSimilarityFunction());
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
writeMeta(
|
writeMeta(
|
||||||
fieldInfo,
|
fieldInfo,
|
||||||
@ -266,13 +268,20 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private OnHeapHnswGraph writeGraph(
|
private OnHeapHnswGraph writeGraph(
|
||||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
RandomAccessVectorValuesProducer vectorValues,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphBuilder hnswGraphBuilder =
|
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
vectorValues, similarityFunction, M, beamWidth, HnswGraphBuilder.randSeed);
|
vectorValues,
|
||||||
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
|
M,
|
||||||
|
beamWidth,
|
||||||
|
HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfos;
|
|||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
import org.apache.lucene.index.SegmentInfo;
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
@ -68,7 +69,8 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
|
|||||||
static final BytesRef INDEX_DIM_COUNT = new BytesRef(" index dimensional count ");
|
static final BytesRef INDEX_DIM_COUNT = new BytesRef(" index dimensional count ");
|
||||||
static final BytesRef DIM_NUM_BYTES = new BytesRef(" dimensional num bytes ");
|
static final BytesRef DIM_NUM_BYTES = new BytesRef(" dimensional num bytes ");
|
||||||
static final BytesRef VECTOR_NUM_DIMS = new BytesRef(" vector number of dimensions ");
|
static final BytesRef VECTOR_NUM_DIMS = new BytesRef(" vector number of dimensions ");
|
||||||
static final BytesRef VECTOR_SEARCH_STRATEGY = new BytesRef(" vector search strategy ");
|
static final BytesRef VECTOR_ENCODING = new BytesRef(" vector encoding ");
|
||||||
|
static final BytesRef VECTOR_SIMILARITY = new BytesRef(" vector similarity ");
|
||||||
static final BytesRef SOFT_DELETES = new BytesRef(" soft-deletes ");
|
static final BytesRef SOFT_DELETES = new BytesRef(" soft-deletes ");
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -156,8 +158,13 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
|
|||||||
int vectorNumDimensions = Integer.parseInt(readString(VECTOR_NUM_DIMS.length, scratch));
|
int vectorNumDimensions = Integer.parseInt(readString(VECTOR_NUM_DIMS.length, scratch));
|
||||||
|
|
||||||
SimpleTextUtil.readLine(input, scratch);
|
SimpleTextUtil.readLine(input, scratch);
|
||||||
assert StringHelper.startsWith(scratch.get(), VECTOR_SEARCH_STRATEGY);
|
assert StringHelper.startsWith(scratch.get(), VECTOR_ENCODING);
|
||||||
String scoreFunction = readString(VECTOR_SEARCH_STRATEGY.length, scratch);
|
String encoding = readString(VECTOR_ENCODING.length, scratch);
|
||||||
|
VectorEncoding vectorEncoding = vectorEncoding(encoding);
|
||||||
|
|
||||||
|
SimpleTextUtil.readLine(input, scratch);
|
||||||
|
assert StringHelper.startsWith(scratch.get(), VECTOR_SIMILARITY);
|
||||||
|
String scoreFunction = readString(VECTOR_SIMILARITY.length, scratch);
|
||||||
VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);
|
VectorSimilarityFunction vectorDistFunc = distanceFunction(scoreFunction);
|
||||||
|
|
||||||
SimpleTextUtil.readLine(input, scratch);
|
SimpleTextUtil.readLine(input, scratch);
|
||||||
@ -179,6 +186,7 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
|
|||||||
indexDimensionalCount,
|
indexDimensionalCount,
|
||||||
dimensionalNumBytes,
|
dimensionalNumBytes,
|
||||||
vectorNumDimensions,
|
vectorNumDimensions,
|
||||||
|
vectorEncoding,
|
||||||
vectorDistFunc,
|
vectorDistFunc,
|
||||||
isSoftDeletesField);
|
isSoftDeletesField);
|
||||||
}
|
}
|
||||||
@ -201,6 +209,10 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
|
|||||||
return DocValuesType.valueOf(dvType);
|
return DocValuesType.valueOf(dvType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public VectorEncoding vectorEncoding(String vectorEncoding) {
|
||||||
|
return VectorEncoding.valueOf(vectorEncoding);
|
||||||
|
}
|
||||||
|
|
||||||
public VectorSimilarityFunction distanceFunction(String scoreFunction) {
|
public VectorSimilarityFunction distanceFunction(String scoreFunction) {
|
||||||
return VectorSimilarityFunction.valueOf(scoreFunction);
|
return VectorSimilarityFunction.valueOf(scoreFunction);
|
||||||
}
|
}
|
||||||
@ -297,7 +309,11 @@ public class SimpleTextFieldInfosFormat extends FieldInfosFormat {
|
|||||||
SimpleTextUtil.write(out, Integer.toString(fi.getVectorDimension()), scratch);
|
SimpleTextUtil.write(out, Integer.toString(fi.getVectorDimension()), scratch);
|
||||||
SimpleTextUtil.writeNewline(out);
|
SimpleTextUtil.writeNewline(out);
|
||||||
|
|
||||||
SimpleTextUtil.write(out, VECTOR_SEARCH_STRATEGY);
|
SimpleTextUtil.write(out, VECTOR_ENCODING);
|
||||||
|
SimpleTextUtil.write(out, fi.getVectorEncoding().name(), scratch);
|
||||||
|
SimpleTextUtil.writeNewline(out);
|
||||||
|
|
||||||
|
SimpleTextUtil.write(out, VECTOR_SIMILARITY);
|
||||||
SimpleTextUtil.write(out, fi.getVectorSimilarityFunction().name(), scratch);
|
SimpleTextUtil.write(out, fi.getVectorSimilarityFunction().name(), scratch);
|
||||||
SimpleTextUtil.writeNewline(out);
|
SimpleTextUtil.writeNewline(out);
|
||||||
|
|
||||||
|
@ -42,6 +42,7 @@ import org.apache.lucene.store.BufferedChecksumIndexInput;
|
|||||||
import org.apache.lucene.store.ChecksumIndexInput;
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
import org.apache.lucene.store.IOContext;
|
import org.apache.lucene.store.IOContext;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.util.BitSet;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.BytesRefBuilder;
|
import org.apache.lucene.util.BytesRefBuilder;
|
||||||
@ -181,6 +182,13 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||||||
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
|
return new TopDocs(new TotalHits(numVisited, relation), topScoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
int numDocs = (int) acceptDocs.cost();
|
||||||
|
return search(field, target, k, BitSet.of(acceptDocs, numDocs), Integer.MAX_VALUE);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
IndexInput clone = dataIn.clone();
|
IndexInput clone = dataIn.clone();
|
||||||
|
@ -23,6 +23,7 @@ import org.apache.lucene.codecs.lucene90.tests.MockTermStateFactory;
|
|||||||
import org.apache.lucene.index.DocValuesType;
|
import org.apache.lucene.index.DocValuesType;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.ByteBuffersDataOutput;
|
import org.apache.lucene.store.ByteBuffersDataOutput;
|
||||||
import org.apache.lucene.store.ByteBuffersIndexOutput;
|
import org.apache.lucene.store.ByteBuffersIndexOutput;
|
||||||
@ -116,6 +117,7 @@ public class TestBlockWriter extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,7 @@ import org.apache.lucene.index.ImpactsEnum;
|
|||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
import org.apache.lucene.index.PostingsEnum;
|
import org.apache.lucene.index.PostingsEnum;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.ByteBuffersDirectory;
|
import org.apache.lucene.store.ByteBuffersDirectory;
|
||||||
import org.apache.lucene.store.DataInput;
|
import org.apache.lucene.store.DataInput;
|
||||||
@ -203,6 +204,7 @@ public class TestSTBlockReader extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
@ -20,8 +20,12 @@ package org.apache.lucene.codecs;
|
|||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.util.Accountable;
|
import org.apache.lucene.util.Accountable;
|
||||||
|
|
||||||
/** Vectors' writer for a field */
|
/**
|
||||||
public abstract class KnnFieldVectorsWriter implements Accountable {
|
* Vectors' writer for a field
|
||||||
|
*
|
||||||
|
* @param <T> an array type; the type of vectors to be written
|
||||||
|
*/
|
||||||
|
public abstract class KnnFieldVectorsWriter<T> implements Accountable {
|
||||||
|
|
||||||
/** Sole constructor */
|
/** Sole constructor */
|
||||||
protected KnnFieldVectorsWriter() {}
|
protected KnnFieldVectorsWriter() {}
|
||||||
@ -30,5 +34,13 @@ public abstract class KnnFieldVectorsWriter implements Accountable {
|
|||||||
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
|
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
|
||||||
* increasing order.
|
* increasing order.
|
||||||
*/
|
*/
|
||||||
public abstract void addValue(int docID, float[] vectorValue) throws IOException;
|
public abstract void addValue(int docID, Object vectorValue) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Used to copy values being indexed to internal storage.
|
||||||
|
*
|
||||||
|
* @param vectorValue an array containing the vector value to add
|
||||||
|
* @return a copy of the value; a new array
|
||||||
|
*/
|
||||||
|
public abstract T copyValue(T vectorValue);
|
||||||
}
|
}
|
||||||
|
@ -18,9 +18,11 @@
|
|||||||
package org.apache.lucene.codecs;
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.codecs.lucene94.Lucene94HnswVectorsFormat;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TopDocsCollector;
|
import org.apache.lucene.search.TopDocsCollector;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
@ -76,6 +78,15 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||||||
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
|
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
|
||||||
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
|
public abstract KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the current KnnVectorsFormat version number. Indexes written using the format will be
|
||||||
|
* "stamped" with this version.
|
||||||
|
*/
|
||||||
|
public int currentVersion() {
|
||||||
|
// return the version supported by older codecs that did not override this method
|
||||||
|
return Lucene94HnswVectorsFormat.VERSION_START;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
|
* EMPTY throws an exception when written. It acts as a sentinel indicating a Codec that does not
|
||||||
* support vectors.
|
* support vectors.
|
||||||
@ -104,6 +115,12 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||||||
return TopDocsCollector.EMPTY_TOPDOCS;
|
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
return TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() {}
|
public void close() {}
|
||||||
|
|
||||||
|
@ -20,12 +20,16 @@ package org.apache.lucene.codecs;
|
|||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.search.HitQueue;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
import org.apache.lucene.util.Accountable;
|
import org.apache.lucene.util.Accountable;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
/** Reads vectors from an index. */
|
/** Reads vectors from an index. */
|
||||||
public abstract class KnnVectorsReader implements Closeable, Accountable {
|
public abstract class KnnVectorsReader implements Closeable, Accountable {
|
||||||
@ -75,11 +79,39 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||||||
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
|
* @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null}
|
||||||
* if they are all allowed to match.
|
* if they are all allowed to match.
|
||||||
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
* @param visitedLimit the maximum number of nodes that the search is allowed to visit
|
||||||
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
|
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
|
||||||
*/
|
*/
|
||||||
public abstract TopDocs search(
|
public abstract TopDocs search(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||||
|
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||||
|
* is derived from the vector similarity in a way that ensures scores are positive and that a
|
||||||
|
* larger score corresponds to a higher ranking.
|
||||||
|
*
|
||||||
|
* <p>The search is exact, guaranteeing the true k closest neighbors will be returned. Typically
|
||||||
|
* this requires an exhaustive scan of the entire index. It is intended to be used when the number
|
||||||
|
* of potential matches is limited.
|
||||||
|
*
|
||||||
|
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor, in
|
||||||
|
* order of their similarity to the query vector (decreasing scores). The {@link TotalHits}
|
||||||
|
* contains the number of documents visited during the search. If the search stopped early because
|
||||||
|
* it hit {@code visitedLimit}, it is indicated through the relation {@code
|
||||||
|
* TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO}.
|
||||||
|
*
|
||||||
|
* <p>The behavior is undefined if the given field doesn't have KNN vectors enabled on its {@link
|
||||||
|
* FieldInfo}. The return value is never {@code null}.
|
||||||
|
*
|
||||||
|
* @param field the vector field to search
|
||||||
|
* @param target the vector-valued query
|
||||||
|
* @param k the number of docs to return
|
||||||
|
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match.
|
||||||
|
* @return the k nearest neighbor documents, along with their (similarity-specific) scores.
|
||||||
|
*/
|
||||||
|
public abstract TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns an instance optimized for merging. This instance may only be consumed in the thread
|
* Returns an instance optimized for merging. This instance may only be consumed in the thread
|
||||||
* that called {@link #getMergeInstance()}.
|
* that called {@link #getMergeInstance()}.
|
||||||
@ -89,4 +121,67 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||||||
public KnnVectorsReader getMergeInstance() {
|
public KnnVectorsReader getMergeInstance() {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** {@link #searchExhaustively} */
|
||||||
|
protected static TopDocs exhaustiveSearch(
|
||||||
|
VectorValues vectorValues,
|
||||||
|
DocIdSetIterator acceptDocs,
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
float[] target,
|
||||||
|
int k)
|
||||||
|
throws IOException {
|
||||||
|
HitQueue queue = new HitQueue(k, true);
|
||||||
|
ScoreDoc topDoc = queue.top();
|
||||||
|
int doc;
|
||||||
|
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||||
|
int vectorDoc = vectorValues.advance(doc);
|
||||||
|
assert vectorDoc == doc;
|
||||||
|
float score = similarityFunction.compare(vectorValues.vectorValue(), target);
|
||||||
|
if (score >= topDoc.score) {
|
||||||
|
topDoc.score = score;
|
||||||
|
topDoc.doc = doc;
|
||||||
|
topDoc = queue.updateTop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return topDocsFromHitQueue(queue, acceptDocs.cost());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** {@link #searchExhaustively} */
|
||||||
|
protected static TopDocs exhaustiveSearch(
|
||||||
|
VectorValues vectorValues,
|
||||||
|
DocIdSetIterator acceptDocs,
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
BytesRef target,
|
||||||
|
int k)
|
||||||
|
throws IOException {
|
||||||
|
HitQueue queue = new HitQueue(k, true);
|
||||||
|
ScoreDoc topDoc = queue.top();
|
||||||
|
int doc;
|
||||||
|
while ((doc = acceptDocs.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||||
|
int vectorDoc = vectorValues.advance(doc);
|
||||||
|
assert vectorDoc == doc;
|
||||||
|
float score = similarityFunction.compare(vectorValues.binaryValue(), target);
|
||||||
|
if (score >= topDoc.score) {
|
||||||
|
topDoc.score = score;
|
||||||
|
topDoc.doc = doc;
|
||||||
|
topDoc = queue.updateTop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return topDocsFromHitQueue(queue, acceptDocs.cost());
|
||||||
|
}
|
||||||
|
|
||||||
|
private static TopDocs topDocsFromHitQueue(HitQueue queue, long numHits) {
|
||||||
|
// Remove any remaining sentinel values
|
||||||
|
while (queue.size() > 0 && queue.top().score < 0) {
|
||||||
|
queue.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
|
||||||
|
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||||
|
topScoreDocs[i] = queue.pop();
|
||||||
|
}
|
||||||
|
|
||||||
|
TotalHits totalHits = new TotalHits(numHits, TotalHits.Relation.EQUAL_TO);
|
||||||
|
return new TopDocs(totalHits, topScoreDocs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -37,14 +37,15 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||||||
protected KnnVectorsWriter() {}
|
protected KnnVectorsWriter() {}
|
||||||
|
|
||||||
/** Add new field for indexing */
|
/** Add new field for indexing */
|
||||||
public abstract KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException;
|
public abstract KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException;
|
||||||
|
|
||||||
/** Flush all buffered data on disk * */
|
/** Flush all buffered data on disk * */
|
||||||
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
|
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
|
||||||
|
|
||||||
/** Write field for merging */
|
/** Write field for merging */
|
||||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
@SuppressWarnings("unchecked")
|
||||||
KnnFieldVectorsWriter writer = addField(fieldInfo);
|
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
|
||||||
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||||
for (int doc = mergedValues.nextDoc();
|
for (int doc = mergedValues.nextDoc();
|
||||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene94;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.FilterVectorValues;
|
||||||
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
/** reads from byte-encoded data */
|
||||||
|
public class ExpandingVectorValues extends FilterVectorValues {
|
||||||
|
|
||||||
|
private final float[] value;
|
||||||
|
|
||||||
|
/** @param in the wrapped values */
|
||||||
|
protected ExpandingVectorValues(VectorValues in) {
|
||||||
|
super(in);
|
||||||
|
value = new float[in.dimension()];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] vectorValue() throws IOException {
|
||||||
|
BytesRef binaryValue = binaryValue();
|
||||||
|
byte[] bytes = binaryValue.bytes;
|
||||||
|
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
|
||||||
|
value[i] = bytes[j];
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
}
|
@ -32,7 +32,6 @@ import org.apache.lucene.codecs.StoredFieldsFormat;
|
|||||||
import org.apache.lucene.codecs.TermVectorsFormat;
|
import org.apache.lucene.codecs.TermVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90CompoundFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat;
|
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90LiveDocsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90NormsFormat;
|
||||||
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
import org.apache.lucene.codecs.lucene90.Lucene90PointsFormat;
|
||||||
@ -69,7 +68,7 @@ public class Lucene94Codec extends Codec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private final TermVectorsFormat vectorsFormat = new Lucene90TermVectorsFormat();
|
private final TermVectorsFormat vectorsFormat = new Lucene90TermVectorsFormat();
|
||||||
private final FieldInfosFormat fieldInfosFormat = new Lucene90FieldInfosFormat();
|
private final FieldInfosFormat fieldInfosFormat = new Lucene94FieldInfosFormat();
|
||||||
private final SegmentInfoFormat segmentInfosFormat = new Lucene90SegmentInfoFormat();
|
private final SegmentInfoFormat segmentInfosFormat = new Lucene90SegmentInfoFormat();
|
||||||
private final LiveDocsFormat liveDocsFormat = new Lucene90LiveDocsFormat();
|
private final LiveDocsFormat liveDocsFormat = new Lucene90LiveDocsFormat();
|
||||||
private final CompoundFormat compoundFormat = new Lucene90CompoundFormat();
|
private final CompoundFormat compoundFormat = new Lucene90CompoundFormat();
|
||||||
@ -100,6 +99,11 @@ public class Lucene94Codec extends Codec {
|
|||||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
return Lucene94Codec.this.getKnnVectorsFormatForField(field);
|
return Lucene94Codec.this.getKnnVectorsFormatForField(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int currentVersion() {
|
||||||
|
return Lucene94HnswVectorsFormat.VERSION_CURRENT;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
private final StoredFieldsFormat storedFieldsFormat;
|
private final StoredFieldsFormat storedFieldsFormat;
|
||||||
|
@ -0,0 +1,385 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.codecs.lucene94;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.DocValuesFormat;
|
||||||
|
import org.apache.lucene.codecs.FieldInfosFormat;
|
||||||
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
|
import org.apache.lucene.index.DocValuesType;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.IndexOptions;
|
||||||
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
|
import org.apache.lucene.store.DataOutput;
|
||||||
|
import org.apache.lucene.store.Directory;
|
||||||
|
import org.apache.lucene.store.IOContext;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lucene 9.0 Field Infos format.
|
||||||
|
*
|
||||||
|
* <p>Field names are stored in the field info file, with suffix <code>.fnm</code>.
|
||||||
|
*
|
||||||
|
* <p>FieldInfos (.fnm) --> Header,FieldsCount, <FieldName,FieldNumber,
|
||||||
|
* FieldBits,DocValuesBits,DocValuesGen,Attributes,DimensionCount,DimensionNumBytes>
|
||||||
|
* <sup>FieldsCount</sup>,Footer
|
||||||
|
*
|
||||||
|
* <p>Data types:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>Header --> {@link CodecUtil#checkIndexHeader IndexHeader}
|
||||||
|
* <li>FieldsCount --> {@link DataOutput#writeVInt VInt}
|
||||||
|
* <li>FieldName --> {@link DataOutput#writeString String}
|
||||||
|
* <li>FieldBits, IndexOptions, DocValuesBits --> {@link DataOutput#writeByte Byte}
|
||||||
|
* <li>FieldNumber, DimensionCount, DimensionNumBytes --> {@link DataOutput#writeInt VInt}
|
||||||
|
* <li>Attributes --> {@link DataOutput#writeMapOfStrings Map<String,String>}
|
||||||
|
* <li>DocValuesGen --> {@link DataOutput#writeLong(long) Int64}
|
||||||
|
* <li>Footer --> {@link CodecUtil#writeFooter CodecFooter}
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* Field Descriptions:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>FieldsCount: the number of fields in this file.
|
||||||
|
* <li>FieldName: name of the field as a UTF-8 String.
|
||||||
|
* <li>FieldNumber: the field's number. Note that unlike previous versions of Lucene, the fields
|
||||||
|
* are not numbered implicitly by their order in the file, instead explicitly.
|
||||||
|
* <li>FieldBits: a byte containing field options.
|
||||||
|
* <ul>
|
||||||
|
* <li>The low order bit (0x1) is one for fields that have term vectors stored, and zero for
|
||||||
|
* fields without term vectors.
|
||||||
|
* <li>If the second lowest order-bit is set (0x2), norms are omitted for the indexed field.
|
||||||
|
* <li>If the third lowest-order bit is set (0x4), payloads are stored for the indexed
|
||||||
|
* field.
|
||||||
|
* </ul>
|
||||||
|
* <li>IndexOptions: a byte containing index options.
|
||||||
|
* <ul>
|
||||||
|
* <li>0: not indexed
|
||||||
|
* <li>1: indexed as DOCS_ONLY
|
||||||
|
* <li>2: indexed as DOCS_AND_FREQS
|
||||||
|
* <li>3: indexed as DOCS_AND_FREQS_AND_POSITIONS
|
||||||
|
* <li>4: indexed as DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS
|
||||||
|
* </ul>
|
||||||
|
* <li>DocValuesBits: a byte containing per-document value types. The type recorded as two
|
||||||
|
* four-bit integers, with the high-order bits representing <code>norms</code> options, and
|
||||||
|
* the low-order bits representing {@code DocValues} options. Each four-bit integer can be
|
||||||
|
* decoded as such:
|
||||||
|
* <ul>
|
||||||
|
* <li>0: no DocValues for this field.
|
||||||
|
* <li>1: NumericDocValues. ({@link DocValuesType#NUMERIC})
|
||||||
|
* <li>2: BinaryDocValues. ({@code DocValuesType#BINARY})
|
||||||
|
* <li>3: SortedDocValues. ({@code DocValuesType#SORTED})
|
||||||
|
* </ul>
|
||||||
|
* <li>DocValuesGen is the generation count of the field's DocValues. If this is -1, there are no
|
||||||
|
* DocValues updates to that field. Anything above zero means there are updates stored by
|
||||||
|
* {@link DocValuesFormat}.
|
||||||
|
* <li>Attributes: a key-value map of codec-private attributes.
|
||||||
|
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
|
||||||
|
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
|
||||||
|
* <li>VectorDimension: it is non-zero if the field is indexed as vectors.
|
||||||
|
* <li>VectorEncoding: a byte containing the encoding of vector values:
|
||||||
|
* <ul>
|
||||||
|
* <li>0: BYTE. Samples are stored as signed bytes
|
||||||
|
* <li>1: FLOAT32. Samples are stored in IEEE 32-bit floating point format.
|
||||||
|
* </ul>
|
||||||
|
* <li>VectorSimilarityFunction: a byte containing distance function used for similarity
|
||||||
|
* calculation.
|
||||||
|
* <ul>
|
||||||
|
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
|
||||||
|
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
|
||||||
|
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
|
||||||
|
* </ul>
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public final class Lucene94FieldInfosFormat extends FieldInfosFormat {
|
||||||
|
|
||||||
|
/** Sole constructor. */
|
||||||
|
public Lucene94FieldInfosFormat() {}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FieldInfos read(
|
||||||
|
Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext context)
|
||||||
|
throws IOException {
|
||||||
|
final String fileName =
|
||||||
|
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
|
||||||
|
try (ChecksumIndexInput input = directory.openChecksumInput(fileName, context)) {
|
||||||
|
Throwable priorE = null;
|
||||||
|
FieldInfo[] infos = null;
|
||||||
|
try {
|
||||||
|
CodecUtil.checkIndexHeader(
|
||||||
|
input,
|
||||||
|
Lucene94FieldInfosFormat.CODEC_NAME,
|
||||||
|
Lucene94FieldInfosFormat.FORMAT_START,
|
||||||
|
Lucene94FieldInfosFormat.FORMAT_CURRENT,
|
||||||
|
segmentInfo.getId(),
|
||||||
|
segmentSuffix);
|
||||||
|
|
||||||
|
final int size = input.readVInt(); // read in the size
|
||||||
|
infos = new FieldInfo[size];
|
||||||
|
|
||||||
|
// previous field's attribute map, we share when possible:
|
||||||
|
Map<String, String> lastAttributes = Collections.emptyMap();
|
||||||
|
|
||||||
|
for (int i = 0; i < size; i++) {
|
||||||
|
String name = input.readString();
|
||||||
|
final int fieldNumber = input.readVInt();
|
||||||
|
if (fieldNumber < 0) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"invalid field number for field: " + name + ", fieldNumber=" + fieldNumber, input);
|
||||||
|
}
|
||||||
|
byte bits = input.readByte();
|
||||||
|
boolean storeTermVector = (bits & STORE_TERMVECTOR) != 0;
|
||||||
|
boolean omitNorms = (bits & OMIT_NORMS) != 0;
|
||||||
|
boolean storePayloads = (bits & STORE_PAYLOADS) != 0;
|
||||||
|
boolean isSoftDeletesField = (bits & SOFT_DELETES_FIELD) != 0;
|
||||||
|
|
||||||
|
final IndexOptions indexOptions = getIndexOptions(input, input.readByte());
|
||||||
|
|
||||||
|
// DV Types are packed in one byte
|
||||||
|
final DocValuesType docValuesType = getDocValuesType(input, input.readByte());
|
||||||
|
final long dvGen = input.readLong();
|
||||||
|
Map<String, String> attributes = input.readMapOfStrings();
|
||||||
|
// just use the last field's map if its the same
|
||||||
|
if (attributes.equals(lastAttributes)) {
|
||||||
|
attributes = lastAttributes;
|
||||||
|
}
|
||||||
|
lastAttributes = attributes;
|
||||||
|
int pointDataDimensionCount = input.readVInt();
|
||||||
|
int pointNumBytes;
|
||||||
|
int pointIndexDimensionCount = pointDataDimensionCount;
|
||||||
|
if (pointDataDimensionCount != 0) {
|
||||||
|
pointIndexDimensionCount = input.readVInt();
|
||||||
|
pointNumBytes = input.readVInt();
|
||||||
|
} else {
|
||||||
|
pointNumBytes = 0;
|
||||||
|
}
|
||||||
|
final int vectorDimension = input.readVInt();
|
||||||
|
final VectorEncoding vectorEncoding = getVectorEncoding(input, input.readByte());
|
||||||
|
final VectorSimilarityFunction vectorDistFunc = getDistFunc(input, input.readByte());
|
||||||
|
|
||||||
|
try {
|
||||||
|
infos[i] =
|
||||||
|
new FieldInfo(
|
||||||
|
name,
|
||||||
|
fieldNumber,
|
||||||
|
storeTermVector,
|
||||||
|
omitNorms,
|
||||||
|
storePayloads,
|
||||||
|
indexOptions,
|
||||||
|
docValuesType,
|
||||||
|
dvGen,
|
||||||
|
attributes,
|
||||||
|
pointDataDimensionCount,
|
||||||
|
pointIndexDimensionCount,
|
||||||
|
pointNumBytes,
|
||||||
|
vectorDimension,
|
||||||
|
vectorEncoding,
|
||||||
|
vectorDistFunc,
|
||||||
|
isSoftDeletesField);
|
||||||
|
infos[i].checkConsistency();
|
||||||
|
} catch (IllegalStateException e) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"invalid fieldinfo for field: " + name + ", fieldNumber=" + fieldNumber, input, e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (Throwable exception) {
|
||||||
|
priorE = exception;
|
||||||
|
} finally {
|
||||||
|
CodecUtil.checkFooter(input, priorE);
|
||||||
|
}
|
||||||
|
return new FieldInfos(infos);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static {
|
||||||
|
// We "mirror" DocValues enum values with the constants below; let's try to ensure if we add a
|
||||||
|
// new DocValuesType while this format is
|
||||||
|
// still used for writing, we remember to fix this encoding:
|
||||||
|
assert DocValuesType.values().length == 6;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte docValuesByte(DocValuesType type) {
|
||||||
|
switch (type) {
|
||||||
|
case NONE:
|
||||||
|
return 0;
|
||||||
|
case NUMERIC:
|
||||||
|
return 1;
|
||||||
|
case BINARY:
|
||||||
|
return 2;
|
||||||
|
case SORTED:
|
||||||
|
return 3;
|
||||||
|
case SORTED_SET:
|
||||||
|
return 4;
|
||||||
|
case SORTED_NUMERIC:
|
||||||
|
return 5;
|
||||||
|
default:
|
||||||
|
// BUG
|
||||||
|
throw new AssertionError("unhandled DocValuesType: " + type);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static DocValuesType getDocValuesType(IndexInput input, byte b) throws IOException {
|
||||||
|
switch (b) {
|
||||||
|
case 0:
|
||||||
|
return DocValuesType.NONE;
|
||||||
|
case 1:
|
||||||
|
return DocValuesType.NUMERIC;
|
||||||
|
case 2:
|
||||||
|
return DocValuesType.BINARY;
|
||||||
|
case 3:
|
||||||
|
return DocValuesType.SORTED;
|
||||||
|
case 4:
|
||||||
|
return DocValuesType.SORTED_SET;
|
||||||
|
case 5:
|
||||||
|
return DocValuesType.SORTED_NUMERIC;
|
||||||
|
default:
|
||||||
|
throw new CorruptIndexException("invalid docvalues byte: " + b, input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static VectorEncoding getVectorEncoding(IndexInput input, byte b) throws IOException {
|
||||||
|
if (b < 0 || b >= VectorEncoding.values().length) {
|
||||||
|
throw new CorruptIndexException("invalid vector encoding: " + b, input);
|
||||||
|
}
|
||||||
|
return VectorEncoding.values()[b];
|
||||||
|
}
|
||||||
|
|
||||||
|
private static VectorSimilarityFunction getDistFunc(IndexInput input, byte b) throws IOException {
|
||||||
|
if (b < 0 || b >= VectorSimilarityFunction.values().length) {
|
||||||
|
throw new CorruptIndexException("invalid distance function: " + b, input);
|
||||||
|
}
|
||||||
|
return VectorSimilarityFunction.values()[b];
|
||||||
|
}
|
||||||
|
|
||||||
|
static {
|
||||||
|
// We "mirror" IndexOptions enum values with the constants below; let's try to ensure if we add
|
||||||
|
// a new IndexOption while this format is
|
||||||
|
// still used for writing, we remember to fix this encoding:
|
||||||
|
assert IndexOptions.values().length == 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static byte indexOptionsByte(IndexOptions indexOptions) {
|
||||||
|
switch (indexOptions) {
|
||||||
|
case NONE:
|
||||||
|
return 0;
|
||||||
|
case DOCS:
|
||||||
|
return 1;
|
||||||
|
case DOCS_AND_FREQS:
|
||||||
|
return 2;
|
||||||
|
case DOCS_AND_FREQS_AND_POSITIONS:
|
||||||
|
return 3;
|
||||||
|
case DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS:
|
||||||
|
return 4;
|
||||||
|
default:
|
||||||
|
// BUG:
|
||||||
|
throw new AssertionError("unhandled IndexOptions: " + indexOptions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IndexOptions getIndexOptions(IndexInput input, byte b) throws IOException {
|
||||||
|
switch (b) {
|
||||||
|
case 0:
|
||||||
|
return IndexOptions.NONE;
|
||||||
|
case 1:
|
||||||
|
return IndexOptions.DOCS;
|
||||||
|
case 2:
|
||||||
|
return IndexOptions.DOCS_AND_FREQS;
|
||||||
|
case 3:
|
||||||
|
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS;
|
||||||
|
case 4:
|
||||||
|
return IndexOptions.DOCS_AND_FREQS_AND_POSITIONS_AND_OFFSETS;
|
||||||
|
default:
|
||||||
|
// BUG
|
||||||
|
throw new CorruptIndexException("invalid IndexOptions byte: " + b, input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void write(
|
||||||
|
Directory directory,
|
||||||
|
SegmentInfo segmentInfo,
|
||||||
|
String segmentSuffix,
|
||||||
|
FieldInfos infos,
|
||||||
|
IOContext context)
|
||||||
|
throws IOException {
|
||||||
|
final String fileName =
|
||||||
|
IndexFileNames.segmentFileName(segmentInfo.name, segmentSuffix, EXTENSION);
|
||||||
|
try (IndexOutput output = directory.createOutput(fileName, context)) {
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
output,
|
||||||
|
Lucene94FieldInfosFormat.CODEC_NAME,
|
||||||
|
Lucene94FieldInfosFormat.FORMAT_CURRENT,
|
||||||
|
segmentInfo.getId(),
|
||||||
|
segmentSuffix);
|
||||||
|
output.writeVInt(infos.size());
|
||||||
|
for (FieldInfo fi : infos) {
|
||||||
|
fi.checkConsistency();
|
||||||
|
|
||||||
|
output.writeString(fi.name);
|
||||||
|
output.writeVInt(fi.number);
|
||||||
|
|
||||||
|
byte bits = 0x0;
|
||||||
|
if (fi.hasVectors()) bits |= STORE_TERMVECTOR;
|
||||||
|
if (fi.omitsNorms()) bits |= OMIT_NORMS;
|
||||||
|
if (fi.hasPayloads()) bits |= STORE_PAYLOADS;
|
||||||
|
if (fi.isSoftDeletesField()) bits |= SOFT_DELETES_FIELD;
|
||||||
|
output.writeByte(bits);
|
||||||
|
|
||||||
|
output.writeByte(indexOptionsByte(fi.getIndexOptions()));
|
||||||
|
|
||||||
|
// pack the DV type and hasNorms in one byte
|
||||||
|
output.writeByte(docValuesByte(fi.getDocValuesType()));
|
||||||
|
output.writeLong(fi.getDocValuesGen());
|
||||||
|
output.writeMapOfStrings(fi.attributes());
|
||||||
|
output.writeVInt(fi.getPointDimensionCount());
|
||||||
|
if (fi.getPointDimensionCount() != 0) {
|
||||||
|
output.writeVInt(fi.getPointIndexDimensionCount());
|
||||||
|
output.writeVInt(fi.getPointNumBytes());
|
||||||
|
}
|
||||||
|
output.writeVInt(fi.getVectorDimension());
|
||||||
|
output.writeByte((byte) fi.getVectorEncoding().ordinal());
|
||||||
|
output.writeByte((byte) fi.getVectorSimilarityFunction().ordinal());
|
||||||
|
}
|
||||||
|
CodecUtil.writeFooter(output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Extension of field infos */
|
||||||
|
static final String EXTENSION = "fnm";
|
||||||
|
|
||||||
|
// Codec header
|
||||||
|
static final String CODEC_NAME = "Lucene90FieldInfos";
|
||||||
|
static final int FORMAT_START = 0;
|
||||||
|
static final int FORMAT_CURRENT = FORMAT_START;
|
||||||
|
|
||||||
|
// Field flags
|
||||||
|
static final byte STORE_TERMVECTOR = 0x1;
|
||||||
|
static final byte OMIT_NORMS = 0x2;
|
||||||
|
static final byte STORE_PAYLOADS = 0x4;
|
||||||
|
static final byte SOFT_DELETES_FIELD = 0x8;
|
||||||
|
}
|
@ -38,8 +38,9 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||||||
* <p>For each field:
|
* <p>For each field:
|
||||||
*
|
*
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li>Floating-point vector data ordered by field, document ordinal, and vector dimension. The
|
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
|
||||||
* floats are stored in little-endian byte order
|
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
|
||||||
|
* sample is stored as an IEEE float in little-endian byte order.
|
||||||
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
||||||
* note that only in sparse case
|
* note that only in sparse case
|
||||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||||
@ -89,7 +90,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||||||
* <ul>
|
* <ul>
|
||||||
* <li><b>[int]</b> the number of nodes on this level
|
* <li><b>[int]</b> the number of nodes on this level
|
||||||
* <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as
|
* <li><b>array[int]</b> for levels greater than 0 list of nodes on this level, stored as
|
||||||
* the the level 0th nodes ordinals.
|
* the level 0th nodes' ordinals.
|
||||||
* </ul>
|
* </ul>
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
@ -104,8 +105,8 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
|
|||||||
static final String VECTOR_DATA_EXTENSION = "vec";
|
static final String VECTOR_DATA_EXTENSION = "vec";
|
||||||
static final String VECTOR_INDEX_EXTENSION = "vex";
|
static final String VECTOR_INDEX_EXTENSION = "vex";
|
||||||
|
|
||||||
static final int VERSION_START = 0;
|
public static final int VERSION_START = 0;
|
||||||
static final int VERSION_CURRENT = VERSION_START;
|
public static final int VERSION_CURRENT = 1;
|
||||||
|
|
||||||
/** Default number of maximum connections per node */
|
/** Default number of maximum connections per node */
|
||||||
public static final int DEFAULT_MAX_CONN = 16;
|
public static final int DEFAULT_MAX_CONN = 16;
|
||||||
@ -156,6 +157,11 @@ public final class Lucene94HnswVectorsFormat extends KnnVectorsFormat {
|
|||||||
return new Lucene94HnswVectorsReader(state);
|
return new Lucene94HnswVectorsReader(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int currentVersion() {
|
||||||
|
return VERSION_CURRENT;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn="
|
return "Lucene94HnswVectorsFormat(name=Lucene94HnswVectorsFormat, maxConn="
|
||||||
|
@ -18,6 +18,8 @@
|
|||||||
package org.apache.lucene.codecs.lucene94;
|
package org.apache.lucene.codecs.lucene94;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.search.TopDocsCollector.EMPTY_TOPDOCS;
|
||||||
|
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
@ -30,8 +32,10 @@ import org.apache.lucene.index.FieldInfo;
|
|||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -169,16 +173,23 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
+ fieldEntry.dimension);
|
+ fieldEntry.dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
long numBytes = (long) fieldEntry.size() * dimension * Float.BYTES;
|
int byteSize =
|
||||||
|
switch (info.getVectorEncoding()) {
|
||||||
|
case BYTE -> Byte.BYTES;
|
||||||
|
case FLOAT32 -> Float.BYTES;
|
||||||
|
};
|
||||||
|
int numBytes = fieldEntry.size * dimension * byteSize;
|
||||||
if (numBytes != fieldEntry.vectorDataLength) {
|
if (numBytes != fieldEntry.vectorDataLength) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Vector data length "
|
"Vector data length "
|
||||||
+ fieldEntry.vectorDataLength
|
+ fieldEntry.vectorDataLength
|
||||||
+ " not matching size="
|
+ " not matching size="
|
||||||
+ fieldEntry.size()
|
+ fieldEntry.size
|
||||||
+ " * dim="
|
+ " * dim="
|
||||||
+ dimension
|
+ dimension
|
||||||
+ " * 4 = "
|
+ " * byteSize="
|
||||||
|
+ byteSize
|
||||||
|
+ " = "
|
||||||
+ numBytes);
|
+ numBytes);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,9 +204,18 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
return VectorSimilarityFunction.values()[similarityFunctionId];
|
return VectorSimilarityFunction.values()[similarityFunctionId];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||||
|
int encodingId = input.readInt();
|
||||||
|
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||||
|
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||||
|
}
|
||||||
|
return VectorEncoding.values()[encodingId];
|
||||||
|
}
|
||||||
|
|
||||||
private FieldEntry readField(IndexInput input) throws IOException {
|
private FieldEntry readField(IndexInput input) throws IOException {
|
||||||
|
VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||||
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||||
return new FieldEntry(input, similarityFunction);
|
return new FieldEntry(input, vectorEncoding, similarityFunction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -216,7 +236,12 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||||
|
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
return new ExpandingVectorValues(values);
|
||||||
|
} else {
|
||||||
|
return values;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -237,6 +262,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
target,
|
target,
|
||||||
k,
|
k,
|
||||||
vectorValues,
|
vectorValues,
|
||||||
|
fieldEntry.vectorEncoding,
|
||||||
fieldEntry.similarityFunction,
|
fieldEntry.similarityFunction,
|
||||||
getGraph(fieldEntry),
|
getGraph(fieldEntry),
|
||||||
vectorValues.getAcceptOrds(acceptDocs),
|
vectorValues.getAcceptOrds(acceptDocs),
|
||||||
@ -258,6 +284,25 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
return new TopDocs(new TotalHits(results.visitedCount(), relation), scoreDocs);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
// The field does not exist or does not index vectors
|
||||||
|
return EMPTY_TOPDOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorSimilarityFunction similarityFunction = fieldEntry.similarityFunction;
|
||||||
|
VectorValues vectorValues = getVectorValues(field);
|
||||||
|
|
||||||
|
return switch (fieldEntry.vectorEncoding) {
|
||||||
|
case BYTE -> exhaustiveSearch(
|
||||||
|
vectorValues, acceptDocs, similarityFunction, toBytesRef(target), k);
|
||||||
|
case FLOAT32 -> exhaustiveSearch(vectorValues, acceptDocs, similarityFunction, target, k);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
/** Get knn graph values; used for testing */
|
/** Get knn graph values; used for testing */
|
||||||
public HnswGraph getGraph(String field) throws IOException {
|
public HnswGraph getGraph(String field) throws IOException {
|
||||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||||
@ -286,6 +331,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
static class FieldEntry {
|
static class FieldEntry {
|
||||||
|
|
||||||
final VectorSimilarityFunction similarityFunction;
|
final VectorSimilarityFunction similarityFunction;
|
||||||
|
final VectorEncoding vectorEncoding;
|
||||||
final long vectorDataOffset;
|
final long vectorDataOffset;
|
||||||
final long vectorDataLength;
|
final long vectorDataLength;
|
||||||
final long vectorIndexOffset;
|
final long vectorIndexOffset;
|
||||||
@ -315,8 +361,13 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||||||
final DirectMonotonicReader.Meta meta;
|
final DirectMonotonicReader.Meta meta;
|
||||||
final long addressesLength;
|
final long addressesLength;
|
||||||
|
|
||||||
FieldEntry(IndexInput input, VectorSimilarityFunction similarityFunction) throws IOException {
|
FieldEntry(
|
||||||
|
IndexInput input,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction)
|
||||||
|
throws IOException {
|
||||||
this.similarityFunction = similarityFunction;
|
this.similarityFunction = similarityFunction;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
vectorDataOffset = input.readVLong();
|
vectorDataOffset = input.readVLong();
|
||||||
vectorDataLength = input.readVLong();
|
vectorDataLength = input.readVLong();
|
||||||
vectorIndexOffset = input.readVLong();
|
vectorIndexOffset = input.readVLong();
|
||||||
|
@ -65,7 +65,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
private final int M;
|
private final int M;
|
||||||
private final int beamWidth;
|
private final int beamWidth;
|
||||||
|
|
||||||
private final List<FieldWriter> fields = new ArrayList<>();
|
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||||
private boolean finished;
|
private boolean finished;
|
||||||
|
|
||||||
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
|
Lucene94HnswVectorsWriter(SegmentWriteState state, int M, int beamWidth) throws IOException {
|
||||||
@ -121,15 +121,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
FieldWriter newField = new FieldWriter(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
FieldWriter<?> newField =
|
||||||
|
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
||||||
fields.add(newField);
|
fields.add(newField);
|
||||||
return newField;
|
return newField;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||||
for (FieldWriter field : fields) {
|
for (FieldWriter<?> field : fields) {
|
||||||
if (sortMap == null) {
|
if (sortMap == null) {
|
||||||
writeField(field, maxDoc);
|
writeField(field, maxDoc);
|
||||||
} else {
|
} else {
|
||||||
@ -159,22 +160,20 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
long total = 0;
|
long total = 0;
|
||||||
for (FieldWriter field : fields) {
|
for (FieldWriter<?> field : fields) {
|
||||||
total += field.ramBytesUsed();
|
total += field.ramBytesUsed();
|
||||||
}
|
}
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
|
private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
|
||||||
// write vector values
|
// write vector values
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
final ByteBuffer buffer =
|
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
case BYTE -> writeByteVectors(fieldData);
|
||||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
||||||
for (float[] vector : fieldData.vectors) {
|
|
||||||
buffer.asFloatBuffer().put(vector);
|
|
||||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
|
||||||
}
|
}
|
||||||
|
;
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
// write graph
|
// write graph
|
||||||
@ -194,7 +193,24 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
graph);
|
graph);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
|
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
|
||||||
|
final ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||||
|
for (Object v : fieldData.vectors) {
|
||||||
|
buffer.asFloatBuffer().put((float[]) v);
|
||||||
|
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
|
||||||
|
for (Object v : fieldData.vectors) {
|
||||||
|
BytesRef vector = (BytesRef) v;
|
||||||
|
vectorData.writeBytes(vector.bytes, vector.offset, vector.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
final int[] docIdOffsets = new int[sortMap.size()];
|
final int[] docIdOffsets = new int[sortMap.size()];
|
||||||
int offset = 1; // 0 means no vector for this (field, document)
|
int offset = 1; // 0 means no vector for this (field, document)
|
||||||
@ -221,15 +237,11 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// write vector values
|
// write vector values
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
long vectorDataOffset =
|
||||||
final ByteBuffer buffer =
|
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
|
||||||
final BytesRef binaryValue = new BytesRef(buffer.array());
|
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
|
||||||
for (int ordinal : ordMap) {
|
};
|
||||||
float[] vector = fieldData.vectors.get(ordinal);
|
|
||||||
buffer.asFloatBuffer().put(vector);
|
|
||||||
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
|
||||||
}
|
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
// write graph
|
// write graph
|
||||||
@ -249,6 +261,29 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
mockGraph);
|
mockGraph);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
|
||||||
|
throws IOException {
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
final ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
final BytesRef binaryValue = new BytesRef(buffer.array());
|
||||||
|
for (int ordinal : ordMap) {
|
||||||
|
float[] vector = (float[]) fieldData.vectors.get(ordinal);
|
||||||
|
buffer.asFloatBuffer().put(vector);
|
||||||
|
vectorData.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
|
}
|
||||||
|
return vectorDataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
for (int ordinal : ordMap) {
|
||||||
|
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
|
||||||
|
vectorData.writeBytes(vector, 0, vector.length);
|
||||||
|
}
|
||||||
|
return vectorDataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
// reconstruct graph substituting old ordinals with new ordinals
|
// reconstruct graph substituting old ordinals with new ordinals
|
||||||
private HnswGraph reconstructAndWriteGraph(
|
private HnswGraph reconstructAndWriteGraph(
|
||||||
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
|
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap) throws IOException {
|
||||||
@ -354,7 +389,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
// write the vector data to a temporary file
|
// write the vector data to a temporary file
|
||||||
DocsWithFieldSet docsWithField = writeVectorData(tempVectorData, vectors);
|
DocsWithFieldSet docsWithField =
|
||||||
|
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
||||||
CodecUtil.writeFooter(tempVectorData);
|
CodecUtil.writeFooter(tempVectorData);
|
||||||
IOUtils.close(tempVectorData);
|
IOUtils.close(tempVectorData);
|
||||||
|
|
||||||
@ -365,21 +401,22 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
||||||
CodecUtil.retrieveChecksum(vectorDataInput);
|
CodecUtil.retrieveChecksum(vectorDataInput);
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||||
// build the graph using the temporary vector data
|
// build the graph using the temporary vector data
|
||||||
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||||
// doesn't need to know docIds
|
// doesn't need to know docIds
|
||||||
// TODO: separate random access vector values from DocIdSetIterator?
|
// TODO: separate random access vector values from DocIdSetIterator?
|
||||||
|
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||||
OffHeapVectorValues offHeapVectors =
|
OffHeapVectorValues offHeapVectors =
|
||||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput);
|
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
||||||
OnHeapHnswGraph graph = null;
|
OnHeapHnswGraph graph = null;
|
||||||
if (offHeapVectors.size() != 0) {
|
if (offHeapVectors.size() != 0) {
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphBuilder hnswGraphBuilder =
|
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
offHeapVectors,
|
offHeapVectors,
|
||||||
|
fieldInfo.getVectorEncoding(),
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
M,
|
M,
|
||||||
beamWidth,
|
beamWidth,
|
||||||
@ -451,6 +488,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
HnswGraph graph)
|
HnswGraph graph)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
meta.writeInt(field.number);
|
meta.writeInt(field.number);
|
||||||
|
meta.writeInt(field.getVectorEncoding().ordinal());
|
||||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||||
meta.writeVLong(vectorDataOffset);
|
meta.writeVLong(vectorDataOffset);
|
||||||
meta.writeVLong(vectorDataLength);
|
meta.writeVLong(vectorDataLength);
|
||||||
@ -520,13 +558,13 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
/**
|
/**
|
||||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||||
*/
|
*/
|
||||||
private static DocsWithFieldSet writeVectorData(IndexOutput output, VectorValues vectors)
|
private static DocsWithFieldSet writeVectorData(
|
||||||
throws IOException {
|
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
||||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||||
// write vector
|
// write vector
|
||||||
BytesRef binaryValue = vectors.binaryValue();
|
BytesRef binaryValue = vectors.binaryValue();
|
||||||
assert binaryValue.length == vectors.dimension() * Float.BYTES;
|
assert binaryValue.length == vectors.dimension() * scalarSize;
|
||||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
docsWithField.add(docV);
|
docsWithField.add(docV);
|
||||||
}
|
}
|
||||||
@ -538,54 +576,69 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
IOUtils.close(meta, vectorData, vectorIndex);
|
IOUtils.close(meta, vectorData, vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||||
private final FieldInfo fieldInfo;
|
private final FieldInfo fieldInfo;
|
||||||
private final int dim;
|
private final int dim;
|
||||||
private final DocsWithFieldSet docsWithField;
|
private final DocsWithFieldSet docsWithField;
|
||||||
private final List<float[]> vectors;
|
private final List<T> vectors;
|
||||||
private final RAVectorValues raVectorValues;
|
private final RAVectorValues<T> raVectorValues;
|
||||||
private final HnswGraphBuilder hnswGraphBuilder;
|
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||||
|
|
||||||
private int lastDocID = -1;
|
private int lastDocID = -1;
|
||||||
private int node = 0;
|
private int node = 0;
|
||||||
|
|
||||||
|
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||||
|
throws IOException {
|
||||||
|
int dim = fieldInfo.getVectorDimension();
|
||||||
|
return switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
||||||
|
@Override
|
||||||
|
public BytesRef copyValue(BytesRef value) {
|
||||||
|
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
||||||
|
@Override
|
||||||
|
public float[] copyValue(float[] value) {
|
||||||
|
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.fieldInfo = fieldInfo;
|
this.fieldInfo = fieldInfo;
|
||||||
this.dim = fieldInfo.getVectorDimension();
|
this.dim = fieldInfo.getVectorDimension();
|
||||||
this.docsWithField = new DocsWithFieldSet();
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
vectors = new ArrayList<>();
|
vectors = new ArrayList<>();
|
||||||
raVectorValues = new RAVectorValues(vectors, dim);
|
raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||||
hnswGraphBuilder =
|
hnswGraphBuilder =
|
||||||
new HnswGraphBuilder(
|
(HnswGraphBuilder<T>)
|
||||||
() -> raVectorValues,
|
HnswGraphBuilder.create(
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
() -> raVectorValues,
|
||||||
M,
|
fieldInfo.getVectorEncoding(),
|
||||||
beamWidth,
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
HnswGraphBuilder.randSeed);
|
M,
|
||||||
|
beamWidth,
|
||||||
|
HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(infoStream);
|
hnswGraphBuilder.setInfoStream(infoStream);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addValue(int docID, float[] vectorValue) throws IOException {
|
@SuppressWarnings("unchecked")
|
||||||
|
public void addValue(int docID, Object value) throws IOException {
|
||||||
if (docID == lastDocID) {
|
if (docID == lastDocID) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"VectorValuesField \""
|
"VectorValuesField \""
|
||||||
+ fieldInfo.name
|
+ fieldInfo.name
|
||||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
}
|
}
|
||||||
if (vectorValue.length != dim) {
|
T vectorValue = (T) value;
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Attempt to index a vector of dimension "
|
|
||||||
+ vectorValue.length
|
|
||||||
+ " but \""
|
|
||||||
+ fieldInfo.name
|
|
||||||
+ "\" has dimension "
|
|
||||||
+ dim);
|
|
||||||
}
|
|
||||||
assert docID > lastDocID;
|
assert docID > lastDocID;
|
||||||
docsWithField.add(docID);
|
docsWithField.add(docID);
|
||||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
vectors.add(copyValue(vectorValue));
|
||||||
if (node > 0) {
|
if (node > 0) {
|
||||||
// start at node 1! node 0 is added implicitly, in the constructor
|
// start at node 1! node 0 is added implicitly, in the constructor
|
||||||
hnswGraphBuilder.addGraphNode(node, vectorValue);
|
hnswGraphBuilder.addGraphNode(node, vectorValue);
|
||||||
@ -608,16 +661,16 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
return docsWithField.ramBytesUsed()
|
return docsWithField.ramBytesUsed()
|
||||||
+ vectors.size()
|
+ vectors.size()
|
||||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||||
+ vectors.size() * vectors.get(0).length * Float.BYTES
|
+ vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize
|
||||||
+ hnswGraphBuilder.getGraph().ramBytesUsed();
|
+ hnswGraphBuilder.getGraph().ramBytesUsed();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class RAVectorValues implements RandomAccessVectorValues {
|
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
||||||
private final List<float[]> vectors;
|
private final List<T> vectors;
|
||||||
private final int dim;
|
private final int dim;
|
||||||
|
|
||||||
RAVectorValues(List<float[]> vectors, int dim) {
|
RAVectorValues(List<T> vectors, int dim) {
|
||||||
this.vectors = vectors;
|
this.vectors = vectors;
|
||||||
this.dim = dim;
|
this.dim = dim;
|
||||||
}
|
}
|
||||||
@ -634,12 +687,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue(int targetOrd) throws IOException {
|
public float[] vectorValue(int targetOrd) throws IOException {
|
||||||
return vectors.get(targetOrd);
|
return (float[]) vectors.get(targetOrd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
return (BytesRef) vectors.get(targetOrd);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -41,11 +41,11 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
protected final int byteSize;
|
protected final int byteSize;
|
||||||
protected final float[] value;
|
protected final float[] value;
|
||||||
|
|
||||||
OffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
OffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||||
this.dimension = dimension;
|
this.dimension = dimension;
|
||||||
this.size = size;
|
this.size = size;
|
||||||
this.slice = slice;
|
this.slice = slice;
|
||||||
byteSize = Float.BYTES * dimension;
|
this.byteSize = byteSize;
|
||||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||||
value = new float[dimension];
|
value = new float[dimension];
|
||||||
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
@ -93,10 +93,16 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
}
|
}
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
int byteSize =
|
||||||
|
switch (fieldEntry.vectorEncoding) {
|
||||||
|
case BYTE -> fieldEntry.dimension;
|
||||||
|
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
|
||||||
|
};
|
||||||
if (fieldEntry.docsWithFieldOffset == -1) {
|
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||||
return new DenseOffHeapVectorValues(fieldEntry.dimension, fieldEntry.size, bytesSlice);
|
return new DenseOffHeapVectorValues(
|
||||||
|
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||||
} else {
|
} else {
|
||||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice);
|
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,8 +112,8 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
|
||||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||||
super(dimension, size, slice);
|
super(dimension, size, slice, byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -145,7 +151,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -167,10 +173,13 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
|
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
|
||||||
|
|
||||||
public SparseOffHeapVectorValues(
|
public SparseOffHeapVectorValues(
|
||||||
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput dataIn, IndexInput slice)
|
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||||
|
IndexInput dataIn,
|
||||||
|
IndexInput slice,
|
||||||
|
int byteSize)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
super(fieldEntry.dimension, fieldEntry.size, slice);
|
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||||
this.fieldEntry = fieldEntry;
|
this.fieldEntry = fieldEntry;
|
||||||
final RandomAccessInput addressesData =
|
final RandomAccessInput addressesData =
|
||||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||||
@ -218,7 +227,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -248,7 +257,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||||||
private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
|
private static class EmptyOffHeapVectorValues extends OffHeapVectorValues {
|
||||||
|
|
||||||
public EmptyOffHeapVectorValues(int dimension) {
|
public EmptyOffHeapVectorValues(int dimension) {
|
||||||
super(dimension, 0, null);
|
super(dimension, 0, null, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
@ -144,7 +144,7 @@
|
|||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90SegmentInfoFormat Segment info}. This
|
||||||
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
* contains metadata about a segment, such as the number of documents, what files it uses, and
|
||||||
* information about how the segment is sorted
|
* information about how the segment is sorted
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Field names}. This
|
* <li>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Field names}. This
|
||||||
* contains metadata about the set of named fields used in the index.
|
* contains metadata about the set of named fields used in the index.
|
||||||
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
* <li>{@link org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat Stored Field values}.
|
||||||
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
* This contains, for each document, a list of attribute-value pairs, where the attributes are
|
||||||
@ -240,7 +240,7 @@
|
|||||||
* systems that frequently run out of file handles.</td>
|
* systems that frequently run out of file handles.</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
* <tr>
|
* <tr>
|
||||||
* <td>{@link org.apache.lucene.codecs.lucene90.Lucene90FieldInfosFormat Fields}</td>
|
* <td>{@link org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat Fields}</td>
|
||||||
* <td>.fnm</td>
|
* <td>.fnm</td>
|
||||||
* <td>Stores information about the fields</td>
|
* <td>Stores information about the fields</td>
|
||||||
* </tr>
|
* </tr>
|
||||||
|
@ -33,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState;
|
|||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -101,7 +102,7 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
KnnVectorsWriter writer = getInstance(fieldInfo);
|
KnnVectorsWriter writer = getInstance(fieldInfo);
|
||||||
return writer.addField(fieldInfo);
|
return writer.addField(fieldInfo);
|
||||||
}
|
}
|
||||||
@ -267,6 +268,17 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||||
|
if (knnVectorsReader == null) {
|
||||||
|
return new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
|
||||||
|
} else {
|
||||||
|
return knnVectorsReader.searchExhaustively(field, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
IOUtils.close(fields.values());
|
IOUtils.close(fields.values());
|
||||||
|
@ -24,6 +24,7 @@ import org.apache.lucene.index.DocValuesType;
|
|||||||
import org.apache.lucene.index.IndexOptions;
|
import org.apache.lucene.index.IndexOptions;
|
||||||
import org.apache.lucene.index.IndexableFieldType;
|
import org.apache.lucene.index.IndexableFieldType;
|
||||||
import org.apache.lucene.index.PointValues;
|
import org.apache.lucene.index.PointValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
|
||||||
@ -44,6 +45,7 @@ public class FieldType implements IndexableFieldType {
|
|||||||
private int indexDimensionCount;
|
private int indexDimensionCount;
|
||||||
private int dimensionNumBytes;
|
private int dimensionNumBytes;
|
||||||
private int vectorDimension;
|
private int vectorDimension;
|
||||||
|
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
private Map<String, String> attributes;
|
private Map<String, String> attributes;
|
||||||
|
|
||||||
@ -62,6 +64,7 @@ public class FieldType implements IndexableFieldType {
|
|||||||
this.indexDimensionCount = ref.pointIndexDimensionCount();
|
this.indexDimensionCount = ref.pointIndexDimensionCount();
|
||||||
this.dimensionNumBytes = ref.pointNumBytes();
|
this.dimensionNumBytes = ref.pointNumBytes();
|
||||||
this.vectorDimension = ref.vectorDimension();
|
this.vectorDimension = ref.vectorDimension();
|
||||||
|
this.vectorEncoding = ref.vectorEncoding();
|
||||||
this.vectorSimilarityFunction = ref.vectorSimilarityFunction();
|
this.vectorSimilarityFunction = ref.vectorSimilarityFunction();
|
||||||
if (ref.getAttributes() != null) {
|
if (ref.getAttributes() != null) {
|
||||||
this.attributes = new HashMap<>(ref.getAttributes());
|
this.attributes = new HashMap<>(ref.getAttributes());
|
||||||
@ -371,8 +374,8 @@ public class FieldType implements IndexableFieldType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Enable vector indexing, with the specified number of dimensions and distance function. */
|
/** Enable vector indexing, with the specified number of dimensions and distance function. */
|
||||||
public void setVectorDimensionsAndSimilarityFunction(
|
public void setVectorAttributes(
|
||||||
int numDimensions, VectorSimilarityFunction distFunc) {
|
int numDimensions, VectorEncoding encoding, VectorSimilarityFunction similarity) {
|
||||||
checkIfFrozen();
|
checkIfFrozen();
|
||||||
if (numDimensions <= 0) {
|
if (numDimensions <= 0) {
|
||||||
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
|
throw new IllegalArgumentException("vector numDimensions must be > 0; got " + numDimensions);
|
||||||
@ -385,7 +388,8 @@ public class FieldType implements IndexableFieldType {
|
|||||||
+ numDimensions);
|
+ numDimensions);
|
||||||
}
|
}
|
||||||
this.vectorDimension = numDimensions;
|
this.vectorDimension = numDimensions;
|
||||||
this.vectorSimilarityFunction = Objects.requireNonNull(distFunc);
|
this.vectorSimilarityFunction = Objects.requireNonNull(similarity);
|
||||||
|
this.vectorEncoding = Objects.requireNonNull(encoding);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@ -393,6 +397,11 @@ public class FieldType implements IndexableFieldType {
|
|||||||
return vectorDimension;
|
return vectorDimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorEncoding vectorEncoding() {
|
||||||
|
return vectorEncoding;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public VectorSimilarityFunction vectorSimilarityFunction() {
|
public VectorSimilarityFunction vectorSimilarityFunction() {
|
||||||
return vectorSimilarityFunction;
|
return vectorSimilarityFunction;
|
||||||
|
@ -17,8 +17,10 @@
|
|||||||
|
|
||||||
package org.apache.lucene.document;
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -39,7 +41,18 @@ public class KnnVectorField extends Field {
|
|||||||
if (v == null) {
|
if (v == null) {
|
||||||
throw new IllegalArgumentException("vector value must not be null");
|
throw new IllegalArgumentException("vector value must not be null");
|
||||||
}
|
}
|
||||||
int dimension = v.length;
|
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
||||||
|
if (v == null) {
|
||||||
|
throw new IllegalArgumentException("vector value must not be null");
|
||||||
|
}
|
||||||
|
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static FieldType createType(
|
||||||
|
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
||||||
if (dimension == 0) {
|
if (dimension == 0) {
|
||||||
throw new IllegalArgumentException("cannot index an empty vector");
|
throw new IllegalArgumentException("cannot index an empty vector");
|
||||||
}
|
}
|
||||||
@ -51,13 +64,13 @@ public class KnnVectorField extends Field {
|
|||||||
throw new IllegalArgumentException("similarity function must not be null");
|
throw new IllegalArgumentException("similarity function must not be null");
|
||||||
}
|
}
|
||||||
FieldType type = new FieldType();
|
FieldType type = new FieldType();
|
||||||
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
|
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
||||||
type.freeze();
|
type.freeze();
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A convenience method for creating a vector field type.
|
* A convenience method for creating a vector field type with the default FLOAT32 encoding.
|
||||||
*
|
*
|
||||||
* @param dimension dimension of vectors
|
* @param dimension dimension of vectors
|
||||||
* @param similarityFunction a function defining vector proximity.
|
* @param similarityFunction a function defining vector proximity.
|
||||||
@ -65,8 +78,21 @@ public class KnnVectorField extends Field {
|
|||||||
*/
|
*/
|
||||||
public static FieldType createFieldType(
|
public static FieldType createFieldType(
|
||||||
int dimension, VectorSimilarityFunction similarityFunction) {
|
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||||
|
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A convenience method for creating a vector field type.
|
||||||
|
*
|
||||||
|
* @param dimension dimension of vectors
|
||||||
|
* @param vectorEncoding the encoding of the scalar values
|
||||||
|
* @param similarityFunction a function defining vector proximity.
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
||||||
|
*/
|
||||||
|
public static FieldType createFieldType(
|
||||||
|
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
||||||
FieldType type = new FieldType();
|
FieldType type = new FieldType();
|
||||||
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
|
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
||||||
type.freeze();
|
type.freeze();
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
@ -74,8 +100,8 @@ public class KnnVectorField extends Field {
|
|||||||
/**
|
/**
|
||||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||||
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be
|
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||||
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
|
* be unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
|
||||||
*
|
*
|
||||||
* @param name field name
|
* @param name field name
|
||||||
* @param vector value
|
* @param vector value
|
||||||
@ -88,6 +114,23 @@ public class KnnVectorField extends Field {
|
|||||||
fieldsData = vector;
|
fieldsData = vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||||
|
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||||
|
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||||
|
* be constant-length.
|
||||||
|
*
|
||||||
|
* @param name field name
|
||||||
|
* @param vector value
|
||||||
|
* @param similarityFunction a function defining vector proximity.
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||||
|
* dimension > 1024.
|
||||||
|
*/
|
||||||
|
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||||
|
super(name, createType(vector, similarityFunction));
|
||||||
|
fieldsData = vector;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||||
* single-valued: each document has either one value or no value. Vectors of a single field share
|
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||||
@ -117,6 +160,21 @@ public class KnnVectorField extends Field {
|
|||||||
fieldsData = vector;
|
fieldsData = vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||||
|
* no value. Vectors of a single field share the same dimension and similarity function.
|
||||||
|
*
|
||||||
|
* @param name field name
|
||||||
|
* @param vector value
|
||||||
|
* @param fieldType field type
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||||
|
* dimension > 1024.
|
||||||
|
*/
|
||||||
|
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||||
|
super(name, fieldType);
|
||||||
|
fieldsData = vector;
|
||||||
|
}
|
||||||
|
|
||||||
/** Return the vector value of this field */
|
/** Return the vector value of this field */
|
||||||
public float[] vectorValue() {
|
public float[] vectorValue() {
|
||||||
return (float[]) fieldsData;
|
return (float[]) fieldsData;
|
||||||
|
@ -45,7 +45,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||||||
protected BufferingKnnVectorsWriter() {}
|
protected BufferingKnnVectorsWriter() {}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<float[]> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
FieldWriter newField = new FieldWriter(fieldInfo);
|
FieldWriter newField = new FieldWriter(fieldInfo);
|
||||||
fields.add(newField);
|
fields.add(newField);
|
||||||
return newField;
|
return newField;
|
||||||
@ -88,6 +88,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
writeField(fieldData.fieldInfo, knnVectorsReader, maxDoc);
|
||||||
@ -122,6 +128,12 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||||
@ -137,7 +149,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||||||
protected abstract void writeField(
|
protected abstract void writeField(
|
||||||
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
|
FieldInfo fieldInfo, KnnVectorsReader knnVectorsReader, int maxDoc) throws IOException;
|
||||||
|
|
||||||
private static class FieldWriter extends KnnFieldVectorsWriter {
|
private static class FieldWriter extends KnnFieldVectorsWriter<float[]> {
|
||||||
private final FieldInfo fieldInfo;
|
private final FieldInfo fieldInfo;
|
||||||
private final int dim;
|
private final int dim;
|
||||||
private final DocsWithFieldSet docsWithField;
|
private final DocsWithFieldSet docsWithField;
|
||||||
@ -153,35 +165,45 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addValue(int docID, float[] vectorValue) {
|
public void addValue(int docID, Object value) {
|
||||||
if (docID == lastDocID) {
|
if (docID == lastDocID) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"VectorValuesField \""
|
"VectorValuesField \""
|
||||||
+ fieldInfo.name
|
+ fieldInfo.name
|
||||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
}
|
}
|
||||||
if (vectorValue.length != dim) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Attempt to index a vector of dimension "
|
|
||||||
+ vectorValue.length
|
|
||||||
+ " but \""
|
|
||||||
+ fieldInfo.name
|
|
||||||
+ "\" has dimension "
|
|
||||||
+ dim);
|
|
||||||
}
|
|
||||||
assert docID > lastDocID;
|
assert docID > lastDocID;
|
||||||
|
float[] vectorValue =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case FLOAT32 -> (float[]) value;
|
||||||
|
case BYTE -> bytesToFloats((BytesRef) value);
|
||||||
|
};
|
||||||
docsWithField.add(docID);
|
docsWithField.add(docID);
|
||||||
vectors.add(ArrayUtil.copyOfSubArray(vectorValue, 0, vectorValue.length));
|
vectors.add(copyValue(vectorValue));
|
||||||
lastDocID = docID;
|
lastDocID = docID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private float[] bytesToFloats(BytesRef b) {
|
||||||
|
// This is used only by SimpleTextKnnVectorsWriter
|
||||||
|
float[] floats = new float[dim];
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
floats[i] = b.bytes[i + b.offset];
|
||||||
|
}
|
||||||
|
return floats;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] copyValue(float[] vectorValue) {
|
||||||
|
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
if (vectors.size() == 0) return 0;
|
if (vectors.size() == 0) return 0;
|
||||||
return docsWithField.ramBytesUsed()
|
return docsWithField.ramBytesUsed()
|
||||||
+ vectors.size()
|
+ vectors.size()
|
||||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||||
+ vectors.size() * vectors.get(0).length * Float.BYTES;
|
+ vectors.size() * dim * Float.BYTES;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -25,6 +25,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||||||
import org.apache.lucene.codecs.PointsReader;
|
import org.apache.lucene.codecs.PointsReader;
|
||||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||||
import org.apache.lucene.codecs.TermVectorsReader;
|
import org.apache.lucene.codecs.TermVectorsReader;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
@ -235,6 +236,19 @@ public abstract class CodecReader extends LeafReader {
|
|||||||
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
|
return getVectorReader().search(field, target, k, acceptDocs, visitedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
ensureOpen();
|
||||||
|
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||||
|
if (fi == null || fi.getVectorDimension() == 0) {
|
||||||
|
// Field does not exist or does not index vectors
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return getVectorReader().searchExhaustively(field, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doClose() throws IOException {}
|
protected void doClose() throws IOException {}
|
||||||
|
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
package org.apache.lucene.index;
|
package org.apache.lucene.index;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
@ -58,6 +59,12 @@ abstract class DocValuesLeafReader extends LeafReader {
|
|||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final void checkIntegrity() throws IOException {
|
public final void checkIntegrity() throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
@ -56,6 +56,7 @@ public final class FieldInfo {
|
|||||||
|
|
||||||
// if it is a positive value, it means this field indexes vectors
|
// if it is a positive value, it means this field indexes vectors
|
||||||
private final int vectorDimension;
|
private final int vectorDimension;
|
||||||
|
private final VectorEncoding vectorEncoding;
|
||||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||||
|
|
||||||
// whether this field is used as the soft-deletes field
|
// whether this field is used as the soft-deletes field
|
||||||
@ -80,6 +81,7 @@ public final class FieldInfo {
|
|||||||
int pointIndexDimensionCount,
|
int pointIndexDimensionCount,
|
||||||
int pointNumBytes,
|
int pointNumBytes,
|
||||||
int vectorDimension,
|
int vectorDimension,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction vectorSimilarityFunction,
|
VectorSimilarityFunction vectorSimilarityFunction,
|
||||||
boolean softDeletesField) {
|
boolean softDeletesField) {
|
||||||
this.name = Objects.requireNonNull(name);
|
this.name = Objects.requireNonNull(name);
|
||||||
@ -105,6 +107,7 @@ public final class FieldInfo {
|
|||||||
this.pointIndexDimensionCount = pointIndexDimensionCount;
|
this.pointIndexDimensionCount = pointIndexDimensionCount;
|
||||||
this.pointNumBytes = pointNumBytes;
|
this.pointNumBytes = pointNumBytes;
|
||||||
this.vectorDimension = vectorDimension;
|
this.vectorDimension = vectorDimension;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||||
this.softDeletesField = softDeletesField;
|
this.softDeletesField = softDeletesField;
|
||||||
this.checkConsistency();
|
this.checkConsistency();
|
||||||
@ -229,8 +232,10 @@ public final class FieldInfo {
|
|||||||
verifySameVectorOptions(
|
verifySameVectorOptions(
|
||||||
fieldName,
|
fieldName,
|
||||||
this.vectorDimension,
|
this.vectorDimension,
|
||||||
|
this.vectorEncoding,
|
||||||
this.vectorSimilarityFunction,
|
this.vectorSimilarityFunction,
|
||||||
o.vectorDimension,
|
o.vectorDimension,
|
||||||
|
o.vectorEncoding,
|
||||||
o.vectorSimilarityFunction);
|
o.vectorSimilarityFunction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -347,19 +352,25 @@ public final class FieldInfo {
|
|||||||
static void verifySameVectorOptions(
|
static void verifySameVectorOptions(
|
||||||
String fieldName,
|
String fieldName,
|
||||||
int vd1,
|
int vd1,
|
||||||
|
VectorEncoding ve1,
|
||||||
VectorSimilarityFunction vsf1,
|
VectorSimilarityFunction vsf1,
|
||||||
int vd2,
|
int vd2,
|
||||||
|
VectorEncoding ve2,
|
||||||
VectorSimilarityFunction vsf2) {
|
VectorSimilarityFunction vsf2) {
|
||||||
if (vd1 != vd2 || vsf1 != vsf2) {
|
if (vd1 != vd2 || vsf1 != vsf2 || ve1 != ve2) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"cannot change field \""
|
"cannot change field \""
|
||||||
+ fieldName
|
+ fieldName
|
||||||
+ "\" from vector dimension="
|
+ "\" from vector dimension="
|
||||||
+ vd1
|
+ vd1
|
||||||
|
+ ", vector encoding="
|
||||||
|
+ ve1
|
||||||
+ ", vector similarity function="
|
+ ", vector similarity function="
|
||||||
+ vsf1
|
+ vsf1
|
||||||
+ " to inconsistent vector dimension="
|
+ " to inconsistent vector dimension="
|
||||||
+ vd2
|
+ vd2
|
||||||
|
+ ", vector encoding="
|
||||||
|
+ ve2
|
||||||
+ ", vector similarity function="
|
+ ", vector similarity function="
|
||||||
+ vsf2);
|
+ vsf2);
|
||||||
}
|
}
|
||||||
@ -470,6 +481,11 @@ public final class FieldInfo {
|
|||||||
return vectorDimension;
|
return vectorDimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns the number of dimensions of the vector value */
|
||||||
|
public VectorEncoding getVectorEncoding() {
|
||||||
|
return vectorEncoding;
|
||||||
|
}
|
||||||
|
|
||||||
/** Returns {@link VectorSimilarityFunction} for the field */
|
/** Returns {@link VectorSimilarityFunction} for the field */
|
||||||
public VectorSimilarityFunction getVectorSimilarityFunction() {
|
public VectorSimilarityFunction getVectorSimilarityFunction() {
|
||||||
return vectorSimilarityFunction;
|
return vectorSimilarityFunction;
|
||||||
|
@ -308,10 +308,15 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
|
|
||||||
static final class FieldVectorProperties {
|
static final class FieldVectorProperties {
|
||||||
final int numDimensions;
|
final int numDimensions;
|
||||||
|
final VectorEncoding vectorEncoding;
|
||||||
final VectorSimilarityFunction similarityFunction;
|
final VectorSimilarityFunction similarityFunction;
|
||||||
|
|
||||||
FieldVectorProperties(int numDimensions, VectorSimilarityFunction similarityFunction) {
|
FieldVectorProperties(
|
||||||
|
int numDimensions,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction) {
|
||||||
this.numDimensions = numDimensions;
|
this.numDimensions = numDimensions;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
this.similarityFunction = similarityFunction;
|
this.similarityFunction = similarityFunction;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -401,7 +406,8 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
fi.getPointNumBytes()));
|
fi.getPointNumBytes()));
|
||||||
vectorProps.put(
|
vectorProps.put(
|
||||||
fieldName,
|
fieldName,
|
||||||
new FieldVectorProperties(fi.getVectorDimension(), fi.getVectorSimilarityFunction()));
|
new FieldVectorProperties(
|
||||||
|
fi.getVectorDimension(), fi.getVectorEncoding(), fi.getVectorSimilarityFunction()));
|
||||||
}
|
}
|
||||||
return fieldNumber.intValue();
|
return fieldNumber.intValue();
|
||||||
}
|
}
|
||||||
@ -459,8 +465,10 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
verifySameVectorOptions(
|
verifySameVectorOptions(
|
||||||
fieldName,
|
fieldName,
|
||||||
props.numDimensions,
|
props.numDimensions,
|
||||||
|
props.vectorEncoding,
|
||||||
props.similarityFunction,
|
props.similarityFunction,
|
||||||
fi.getVectorDimension(),
|
fi.getVectorDimension(),
|
||||||
|
fi.getVectorEncoding(),
|
||||||
fi.getVectorSimilarityFunction());
|
fi.getVectorSimilarityFunction());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -503,6 +511,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
|
(softDeletesFieldName != null && softDeletesFieldName.equals(fieldName)));
|
||||||
addOrGet(fi);
|
addOrGet(fi);
|
||||||
@ -584,6 +593,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
isSoftDeletesField);
|
isSoftDeletesField);
|
||||||
}
|
}
|
||||||
@ -698,6 +708,7 @@ public class FieldInfos implements Iterable<FieldInfo> {
|
|||||||
fi.getPointIndexDimensionCount(),
|
fi.getPointIndexDimensionCount(),
|
||||||
fi.getPointNumBytes(),
|
fi.getPointNumBytes(),
|
||||||
fi.getVectorDimension(),
|
fi.getVectorDimension(),
|
||||||
|
fi.getVectorEncoding(),
|
||||||
fi.getVectorSimilarityFunction(),
|
fi.getVectorSimilarityFunction(),
|
||||||
fi.isSoftDeletesField());
|
fi.isSoftDeletesField());
|
||||||
byName.put(fiNew.getName(), fiNew);
|
byName.put(fiNew.getName(), fiNew);
|
||||||
|
@ -18,6 +18,7 @@ package org.apache.lucene.index;
|
|||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.AttributeSource;
|
import org.apache.lucene.util.AttributeSource;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
@ -357,6 +358,12 @@ public abstract class FilterLeafReader extends LeafReader {
|
|||||||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Fields getTermVectors(int docID) throws IOException {
|
public Fields getTermVectors(int docID) throws IOException {
|
||||||
ensureOpen();
|
ensureOpen();
|
||||||
|
@ -101,6 +101,9 @@ public interface IndexableFieldType {
|
|||||||
/** The number of dimensions of the field's vector value */
|
/** The number of dimensions of the field's vector value */
|
||||||
int vectorDimension();
|
int vectorDimension();
|
||||||
|
|
||||||
|
/** The {@link VectorEncoding} of the field's vector value */
|
||||||
|
VectorEncoding vectorEncoding();
|
||||||
|
|
||||||
/** The {@link VectorSimilarityFunction} of the field's vector value */
|
/** The {@link VectorSimilarityFunction} of the field's vector value */
|
||||||
VectorSimilarityFunction vectorSimilarityFunction();
|
VectorSimilarityFunction vectorSimilarityFunction();
|
||||||
|
|
||||||
|
@ -628,6 +628,7 @@ final class IndexingChain implements Accountable {
|
|||||||
s.pointIndexDimensionCount,
|
s.pointIndexDimensionCount,
|
||||||
s.pointNumBytes,
|
s.pointNumBytes,
|
||||||
s.vectorDimension,
|
s.vectorDimension,
|
||||||
|
s.vectorEncoding,
|
||||||
s.vectorSimilarityFunction,
|
s.vectorSimilarityFunction,
|
||||||
pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName())));
|
pf.fieldName.equals(fieldInfos.getSoftDeletesFieldName())));
|
||||||
pf.setFieldInfo(fi);
|
pf.setFieldInfo(fi);
|
||||||
@ -712,7 +713,11 @@ final class IndexingChain implements Accountable {
|
|||||||
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
||||||
}
|
}
|
||||||
if (fieldType.vectorDimension() != 0) {
|
if (fieldType.vectorDimension() != 0) {
|
||||||
pf.knnFieldVectorsWriter.addValue(docID, ((KnnVectorField) field).vectorValue());
|
switch (fieldType.vectorEncoding()) {
|
||||||
|
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
|
||||||
|
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
|
||||||
|
docID, ((KnnVectorField) field).vectorValue());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return indexedField;
|
return indexedField;
|
||||||
}
|
}
|
||||||
@ -776,7 +781,10 @@ final class IndexingChain implements Accountable {
|
|||||||
fieldType.pointNumBytes());
|
fieldType.pointNumBytes());
|
||||||
}
|
}
|
||||||
if (fieldType.vectorDimension() != 0) {
|
if (fieldType.vectorDimension() != 0) {
|
||||||
schema.setVectors(fieldType.vectorSimilarityFunction(), fieldType.vectorDimension());
|
schema.setVectors(
|
||||||
|
fieldType.vectorEncoding(),
|
||||||
|
fieldType.vectorSimilarityFunction(),
|
||||||
|
fieldType.vectorDimension());
|
||||||
}
|
}
|
||||||
if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) {
|
if (fieldType.getAttributes() != null && fieldType.getAttributes().isEmpty() == false) {
|
||||||
schema.updateAttributes(fieldType.getAttributes());
|
schema.updateAttributes(fieldType.getAttributes());
|
||||||
@ -988,7 +996,7 @@ final class IndexingChain implements Accountable {
|
|||||||
PointValuesWriter pointValuesWriter;
|
PointValuesWriter pointValuesWriter;
|
||||||
|
|
||||||
// Non-null if this field had vectors in this segment
|
// Non-null if this field had vectors in this segment
|
||||||
KnnFieldVectorsWriter knnFieldVectorsWriter;
|
KnnFieldVectorsWriter<?> knnFieldVectorsWriter;
|
||||||
|
|
||||||
/** We use this to know when a PerField is seen for the first time in the current document. */
|
/** We use this to know when a PerField is seen for the first time in the current document. */
|
||||||
long fieldGen = -1;
|
long fieldGen = -1;
|
||||||
@ -1281,6 +1289,7 @@ final class IndexingChain implements Accountable {
|
|||||||
private int pointIndexDimensionCount = 0;
|
private int pointIndexDimensionCount = 0;
|
||||||
private int pointNumBytes = 0;
|
private int pointNumBytes = 0;
|
||||||
private int vectorDimension = 0;
|
private int vectorDimension = 0;
|
||||||
|
private VectorEncoding vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
private VectorSimilarityFunction vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
|
||||||
private static String errMsg =
|
private static String errMsg =
|
||||||
@ -1361,11 +1370,14 @@ final class IndexingChain implements Accountable {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void setVectors(VectorSimilarityFunction similarityFunction, int dimension) {
|
void setVectors(
|
||||||
|
VectorEncoding encoding, VectorSimilarityFunction similarityFunction, int dimension) {
|
||||||
if (vectorDimension == 0) {
|
if (vectorDimension == 0) {
|
||||||
this.vectorDimension = dimension;
|
this.vectorEncoding = encoding;
|
||||||
this.vectorSimilarityFunction = similarityFunction;
|
this.vectorSimilarityFunction = similarityFunction;
|
||||||
|
this.vectorDimension = dimension;
|
||||||
} else {
|
} else {
|
||||||
|
assertSame("vector encoding", vectorEncoding, encoding);
|
||||||
assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction);
|
assertSame("vector similarity function", vectorSimilarityFunction, similarityFunction);
|
||||||
assertSame("vector dimension", vectorDimension, dimension);
|
assertSame("vector dimension", vectorDimension, dimension);
|
||||||
}
|
}
|
||||||
@ -1381,6 +1393,7 @@ final class IndexingChain implements Accountable {
|
|||||||
pointIndexDimensionCount = 0;
|
pointIndexDimensionCount = 0;
|
||||||
pointNumBytes = 0;
|
pointNumBytes = 0;
|
||||||
vectorDimension = 0;
|
vectorDimension = 0;
|
||||||
|
vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
vectorSimilarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1391,6 +1404,7 @@ final class IndexingChain implements Accountable {
|
|||||||
assertSame("doc values type", fi.getDocValuesType(), docValuesType);
|
assertSame("doc values type", fi.getDocValuesType(), docValuesType);
|
||||||
assertSame(
|
assertSame(
|
||||||
"vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction);
|
"vector similarity function", fi.getVectorSimilarityFunction(), vectorSimilarityFunction);
|
||||||
|
assertSame("vector encoding", fi.getVectorEncoding(), vectorEncoding);
|
||||||
assertSame("vector dimension", fi.getVectorDimension(), vectorDimension);
|
assertSame("vector dimension", fi.getVectorDimension(), vectorDimension);
|
||||||
assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount);
|
assertSame("point dimension", fi.getPointDimensionCount(), pointDimensionCount);
|
||||||
assertSame(
|
assertSame(
|
||||||
|
@ -17,6 +17,7 @@
|
|||||||
package org.apache.lucene.index;
|
package org.apache.lucene.index;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.search.TotalHits;
|
import org.apache.lucene.search.TotalHits;
|
||||||
@ -235,6 +236,30 @@ public abstract class LeafReader extends IndexReader {
|
|||||||
public abstract TopDocs searchNearestVectors(
|
public abstract TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||||
|
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||||
|
* is derived from the vector similarity in a way that ensures scores are positive and that a
|
||||||
|
* larger score corresponds to a higher ranking.
|
||||||
|
*
|
||||||
|
* <p>The search is exact, meaning the results are guaranteed to be the true k closest neighbors.
|
||||||
|
* This typically requires an exhaustive scan of all candidate documents.
|
||||||
|
*
|
||||||
|
* <p>The returned {@link TopDocs} will contain a {@link ScoreDoc} for each nearest neighbor,
|
||||||
|
* sorted in order of their similarity to the query vector (decreasing scores). The {@link
|
||||||
|
* TotalHits} contains the number of documents visited during the search.
|
||||||
|
*
|
||||||
|
* @param field the vector field to search
|
||||||
|
* @param target the vector-valued query
|
||||||
|
* @param k the number of docs to return
|
||||||
|
* @param acceptDocs {@link DocIdSetIterator} that represents the allowed documents to match, or
|
||||||
|
* {@code null} if they are all allowed to match.
|
||||||
|
* @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores.
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the {@link FieldInfos} describing all fields in this reader.
|
* Get the {@link FieldInfos} describing all fields in this reader.
|
||||||
*
|
*
|
||||||
|
@ -26,6 +26,7 @@ import java.util.Objects;
|
|||||||
import java.util.Set;
|
import java.util.Set;
|
||||||
import java.util.SortedMap;
|
import java.util.SortedMap;
|
||||||
import java.util.TreeMap;
|
import java.util.TreeMap;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
@ -403,6 +404,16 @@ public class ParallelLeafReader extends LeafReader {
|
|||||||
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
|
: reader.searchNearestVectors(fieldName, target, k, acceptDocs, visitedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String fieldName, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
ensureOpen();
|
||||||
|
LeafReader reader = fieldToReader.get(fieldName);
|
||||||
|
return reader == null
|
||||||
|
? null
|
||||||
|
: reader.searchNearestVectorsExhaustively(fieldName, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
ensureOpen();
|
ensureOpen();
|
||||||
|
@ -722,6 +722,7 @@ final class ReadersAndUpdates {
|
|||||||
fi.getPointIndexDimensionCount(),
|
fi.getPointIndexDimensionCount(),
|
||||||
fi.getPointNumBytes(),
|
fi.getPointNumBytes(),
|
||||||
fi.getVectorDimension(),
|
fi.getVectorDimension(),
|
||||||
|
fi.getVectorEncoding(),
|
||||||
fi.getVectorSimilarityFunction(),
|
fi.getVectorSimilarityFunction(),
|
||||||
fi.isSoftDeletesField());
|
fi.isSoftDeletesField());
|
||||||
}
|
}
|
||||||
|
@ -27,6 +27,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||||||
import org.apache.lucene.codecs.PointsReader;
|
import org.apache.lucene.codecs.PointsReader;
|
||||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||||
import org.apache.lucene.codecs.TermVectorsReader;
|
import org.apache.lucene.codecs.TermVectorsReader;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
@ -172,6 +173,12 @@ public final class SlowCodecReaderWrapper {
|
|||||||
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
return reader.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
return reader.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() {
|
public void checkIntegrity() {
|
||||||
// We already checkIntegrity the entire reader up front
|
// We already checkIntegrity the entire reader up front
|
||||||
|
@ -31,6 +31,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||||||
import org.apache.lucene.codecs.PointsReader;
|
import org.apache.lucene.codecs.PointsReader;
|
||||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||||
import org.apache.lucene.codecs.TermVectorsReader;
|
import org.apache.lucene.codecs.TermVectorsReader;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.SortField;
|
import org.apache.lucene.search.SortField;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
@ -389,6 +390,12 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
delegate.close();
|
delegate.close();
|
||||||
|
@ -0,0 +1,45 @@
|
|||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.index;
|
||||||
|
|
||||||
|
/** The numeric datatype of the vector values. */
|
||||||
|
public enum VectorEncoding {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Encodes vector using 8 bits of precision per sample. Use only with DOT_PRODUCT similarity.
|
||||||
|
* NOTE: this can enable significant storage savings and faster searches, at the cost of some
|
||||||
|
* possible loss of precision. In order to use it, all vectors must be of the same norm, as
|
||||||
|
* measured by the sum of the squares of the scalar values, and those values must be in the range
|
||||||
|
* [-128, 127]. This applies to both document and query vectors. Using nonconforming vectors can
|
||||||
|
* result in errors or poor search results.
|
||||||
|
*/
|
||||||
|
BYTE(1),
|
||||||
|
|
||||||
|
/** Encodes vector using 32 bits of precision per sample in IEEE floating point format. */
|
||||||
|
FLOAT32(4);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of bytes required to encode a scalar in this format. A vector will require dimension
|
||||||
|
* * byteSize.
|
||||||
|
*/
|
||||||
|
public final int byteSize;
|
||||||
|
|
||||||
|
VectorEncoding(int byteSize) {
|
||||||
|
this.byteSize = byteSize;
|
||||||
|
}
|
||||||
|
}
|
@ -18,6 +18,8 @@ package org.apache.lucene.index;
|
|||||||
|
|
||||||
import static org.apache.lucene.util.VectorUtil.*;
|
import static org.apache.lucene.util.VectorUtil.*;
|
||||||
|
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Vector similarity function; used in search to return top K most similar vectors to a target
|
* Vector similarity function; used in search to return top K most similar vectors to a target
|
||||||
* vector. This is a label describing the method used during indexing and searching of the vectors
|
* vector. This is a label describing the method used during indexing and searching of the vectors
|
||||||
@ -31,6 +33,11 @@ public enum VectorSimilarityFunction {
|
|||||||
public float compare(float[] v1, float[] v2) {
|
public float compare(float[] v1, float[] v2) {
|
||||||
return 1 / (1 + squareDistance(v1, v2));
|
return 1 / (1 + squareDistance(v1, v2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float compare(BytesRef v1, BytesRef v2) {
|
||||||
|
return 1 / (1 + squareDistance(v1, v2));
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -44,6 +51,11 @@ public enum VectorSimilarityFunction {
|
|||||||
public float compare(float[] v1, float[] v2) {
|
public float compare(float[] v1, float[] v2) {
|
||||||
return (1 + dotProduct(v1, v2)) / 2;
|
return (1 + dotProduct(v1, v2)) / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float compare(BytesRef v1, BytesRef v2) {
|
||||||
|
return dotProductScore(v1, v2);
|
||||||
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -57,6 +69,11 @@ public enum VectorSimilarityFunction {
|
|||||||
public float compare(float[] v1, float[] v2) {
|
public float compare(float[] v1, float[] v2) {
|
||||||
return (1 + cosine(v1, v2)) / 2;
|
return (1 + cosine(v1, v2)) / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float compare(BytesRef v1, BytesRef v2) {
|
||||||
|
return (1 + cosine(v1, v2)) / 2;
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -68,4 +85,15 @@ public enum VectorSimilarityFunction {
|
|||||||
* @return the value of the similarity function applied to the two vectors
|
* @return the value of the similarity function applied to the two vectors
|
||||||
*/
|
*/
|
||||||
public abstract float compare(float[] v1, float[] v2);
|
public abstract float compare(float[] v1, float[] v2);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculates a similarity score between the two vectors with a specified function. Higher
|
||||||
|
* similarity scores correspond to closer vectors. The offsets and lengths of the BytesRefs
|
||||||
|
* determine the vector data that is compared. Each (signed) byte represents a vector dimension.
|
||||||
|
*
|
||||||
|
* @param v1 a vector
|
||||||
|
* @param v2 another vector, of the same dimension
|
||||||
|
* @return the value of the similarity function applied to the two vectors
|
||||||
|
*/
|
||||||
|
public abstract float compare(BytesRef v1, BytesRef v2);
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ class VectorValuesConsumer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
initKnnVectorsWriter(fieldInfo.name);
|
initKnnVectorsWriter(fieldInfo.name);
|
||||||
return writer.addField(fieldInfo);
|
return writer.addField(fieldInfo);
|
||||||
}
|
}
|
||||||
|
@ -24,11 +24,8 @@ import java.util.Comparator;
|
|||||||
import java.util.Objects;
|
import java.util.Objects;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
|
||||||
import org.apache.lucene.index.VectorValues;
|
|
||||||
import org.apache.lucene.util.BitSet;
|
import org.apache.lucene.util.BitSet;
|
||||||
import org.apache.lucene.util.BitSetIterator;
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
@ -133,22 +130,21 @@ public class KnnVectorQuery extends Query {
|
|||||||
return NO_RESULTS;
|
return NO_RESULTS;
|
||||||
}
|
}
|
||||||
|
|
||||||
BitSet bitSet = createBitSet(scorer.iterator(), liveDocs, maxDoc);
|
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, maxDoc);
|
||||||
BitSetIterator filterIterator = new BitSetIterator(bitSet, bitSet.cardinality());
|
|
||||||
|
|
||||||
if (filterIterator.cost() <= k) {
|
if (acceptDocs.cardinality() <= k) {
|
||||||
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
// If there are <= k possible matches, short-circuit and perform exact search, since HNSW
|
||||||
// must always visit at least k documents
|
// must always visit at least k documents
|
||||||
return exactSearch(ctx, filterIterator);
|
return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Perform the approximate kNN search
|
// Perform the approximate kNN search
|
||||||
TopDocs results = approximateSearch(ctx, bitSet, (int) filterIterator.cost());
|
TopDocs results = approximateSearch(ctx, acceptDocs, acceptDocs.cardinality());
|
||||||
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO) {
|
||||||
return results;
|
return results;
|
||||||
} else {
|
} else {
|
||||||
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
// We stopped the kNN search because it visited too many nodes, so fall back to exact search
|
||||||
return exactSearch(ctx, filterIterator);
|
return exactSearch(ctx, new BitSetIterator(acceptDocs, acceptDocs.cardinality()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -178,45 +174,9 @@ public class KnnVectorQuery extends Query {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We allow this to be overridden so that tests can check what search strategy is used
|
// We allow this to be overridden so that tests can check what search strategy is used
|
||||||
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptIterator)
|
protected TopDocs exactSearch(LeafReaderContext context, DocIdSetIterator acceptDocs)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
return context.reader().searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||||
if (fi == null || fi.getVectorDimension() == 0) {
|
|
||||||
// The field does not exist or does not index vectors
|
|
||||||
return NO_RESULTS;
|
|
||||||
}
|
|
||||||
|
|
||||||
VectorSimilarityFunction similarityFunction = fi.getVectorSimilarityFunction();
|
|
||||||
VectorValues vectorValues = context.reader().getVectorValues(field);
|
|
||||||
|
|
||||||
HitQueue queue = new HitQueue(k, true);
|
|
||||||
ScoreDoc topDoc = queue.top();
|
|
||||||
int doc;
|
|
||||||
while ((doc = acceptIterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
|
||||||
int vectorDoc = vectorValues.advance(doc);
|
|
||||||
assert vectorDoc == doc;
|
|
||||||
float[] vector = vectorValues.vectorValue();
|
|
||||||
|
|
||||||
float score = similarityFunction.compare(vector, target);
|
|
||||||
if (score >= topDoc.score) {
|
|
||||||
topDoc.score = score;
|
|
||||||
topDoc.doc = doc;
|
|
||||||
topDoc = queue.updateTop();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remove any remaining sentinel values
|
|
||||||
while (queue.size() > 0 && queue.top().score < 0) {
|
|
||||||
queue.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
ScoreDoc[] topScoreDocs = new ScoreDoc[queue.size()];
|
|
||||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
|
||||||
topScoreDocs[i] = queue.pop();
|
|
||||||
}
|
|
||||||
|
|
||||||
TotalHits totalHits = new TotalHits(acceptIterator.cost(), TotalHits.Relation.EQUAL_TO);
|
|
||||||
return new TopDocs(totalHits, topScoreDocs);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
|
||||||
|
@ -121,6 +121,24 @@ public final class VectorUtil {
|
|||||||
return (float) (sum / Math.sqrt(norm1 * norm2));
|
return (float) (sum / Math.sqrt(norm1 * norm2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns the cosine similarity between the two vectors. */
|
||||||
|
public static float cosine(BytesRef a, BytesRef b) {
|
||||||
|
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||||
|
int sum = 0;
|
||||||
|
int norm1 = 0;
|
||||||
|
int norm2 = 0;
|
||||||
|
int aOffset = a.offset, bOffset = b.offset;
|
||||||
|
|
||||||
|
for (int i = 0; i < a.length; i++) {
|
||||||
|
byte elem1 = a.bytes[aOffset++];
|
||||||
|
byte elem2 = b.bytes[bOffset++];
|
||||||
|
sum += elem1 * elem2;
|
||||||
|
norm1 += elem1 * elem1;
|
||||||
|
norm2 += elem2 * elem2;
|
||||||
|
}
|
||||||
|
return (float) (sum / Math.sqrt((double) norm1 * (double) norm2));
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns the sum of squared differences of the two vectors.
|
* Returns the sum of squared differences of the two vectors.
|
||||||
*
|
*
|
||||||
@ -135,7 +153,7 @@ public final class VectorUtil {
|
|||||||
int dim = v1.length;
|
int dim = v1.length;
|
||||||
int i;
|
int i;
|
||||||
for (i = 0; i + 8 <= dim; i += 8) {
|
for (i = 0; i + 8 <= dim; i += 8) {
|
||||||
squareSum += squareDistanceUnrolled8(v1, v2, i);
|
squareSum += squareDistanceUnrolled(v1, v2, i);
|
||||||
}
|
}
|
||||||
for (; i < dim; i++) {
|
for (; i < dim; i++) {
|
||||||
float diff = v1[i] - v2[i];
|
float diff = v1[i] - v2[i];
|
||||||
@ -144,7 +162,7 @@ public final class VectorUtil {
|
|||||||
return squareSum;
|
return squareSum;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float squareDistanceUnrolled8(float[] v1, float[] v2, int index) {
|
private static float squareDistanceUnrolled(float[] v1, float[] v2, int index) {
|
||||||
float diff0 = v1[index + 0] - v2[index + 0];
|
float diff0 = v1[index + 0] - v2[index + 0];
|
||||||
float diff1 = v1[index + 1] - v2[index + 1];
|
float diff1 = v1[index + 1] - v2[index + 1];
|
||||||
float diff2 = v1[index + 2] - v2[index + 2];
|
float diff2 = v1[index + 2] - v2[index + 2];
|
||||||
@ -163,6 +181,18 @@ public final class VectorUtil {
|
|||||||
+ diff7 * diff7;
|
+ diff7 * diff7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Returns the sum of squared differences of the two vectors. */
|
||||||
|
public static float squareDistance(BytesRef a, BytesRef b) {
|
||||||
|
// Note: this will not overflow if dim < 2^18, since max(byte * byte) = 2^14.
|
||||||
|
int squareSum = 0;
|
||||||
|
int aOffset = a.offset, bOffset = b.offset;
|
||||||
|
for (int i = 0; i < a.length; i++) {
|
||||||
|
int diff = a.bytes[aOffset++] - b.bytes[bOffset++];
|
||||||
|
squareSum += diff * diff;
|
||||||
|
}
|
||||||
|
return squareSum;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
|
* Modifies the argument to be unit length, dividing by its l2-norm. IllegalArgumentException is
|
||||||
* thrown for zero vectors.
|
* thrown for zero vectors.
|
||||||
@ -213,4 +243,48 @@ public final class VectorUtil {
|
|||||||
u[i] += v[i];
|
u[i] += v[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dot product computed over signed bytes.
|
||||||
|
*
|
||||||
|
* @param a bytes containing a vector
|
||||||
|
* @param b bytes containing another vector, of the same dimension
|
||||||
|
* @return the value of the dot product of the two vectors
|
||||||
|
*/
|
||||||
|
public static float dotProduct(BytesRef a, BytesRef b) {
|
||||||
|
assert a.length == b.length;
|
||||||
|
int total = 0;
|
||||||
|
int aOffset = a.offset, bOffset = b.offset;
|
||||||
|
for (int i = 0; i < a.length; i++) {
|
||||||
|
total += a.bytes[aOffset++] * b.bytes[bOffset++];
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Dot product score computed over signed bytes, scaled to be in [0, 1].
|
||||||
|
*
|
||||||
|
* @param a bytes containing a vector
|
||||||
|
* @param b bytes containing another vector, of the same dimension
|
||||||
|
* @return the value of the similarity function applied to the two vectors
|
||||||
|
*/
|
||||||
|
public static float dotProductScore(BytesRef a, BytesRef b) {
|
||||||
|
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
|
||||||
|
return (1 + dotProduct(a, b)) / (float) (a.length * (1 << 15));
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert a floating point vector to an array of bytes using casting; the vector values should be
|
||||||
|
* in [-128,127]
|
||||||
|
*
|
||||||
|
* @param vector a vector
|
||||||
|
* @return a new BytesRef containing the vector's values cast to byte.
|
||||||
|
*/
|
||||||
|
public static BytesRef toBytesRef(float[] vector) {
|
||||||
|
BytesRef b = new BytesRef(new byte[vector.length]);
|
||||||
|
for (int i = 0; i < vector.length; i++) {
|
||||||
|
b.bytes[i] = (byte) vector[i];
|
||||||
|
}
|
||||||
|
return b;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,15 +25,19 @@ import java.util.Objects;
|
|||||||
import java.util.SplittableRandom;
|
import java.util.SplittableRandom;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.InfoStream;
|
import org.apache.lucene.util.InfoStream;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
|
* Builder for HNSW graph. See {@link HnswGraph} for a gloss on the algorithm and the meaning of the
|
||||||
* hyperparameters.
|
* hyperparameters.
|
||||||
|
*
|
||||||
|
* @param <T> the type of vector
|
||||||
*/
|
*/
|
||||||
public final class HnswGraphBuilder {
|
public final class HnswGraphBuilder<T> {
|
||||||
|
|
||||||
/** Default random seed for level generation * */
|
/** Default random seed for level generation * */
|
||||||
private static final long DEFAULT_RAND_SEED = 42;
|
private static final long DEFAULT_RAND_SEED = 42;
|
||||||
@ -49,9 +53,10 @@ public final class HnswGraphBuilder {
|
|||||||
private final NeighborArray scratch;
|
private final NeighborArray scratch;
|
||||||
|
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
|
private final VectorEncoding vectorEncoding;
|
||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues vectorValues;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final HnswGraphSearcher graphSearcher;
|
private final HnswGraphSearcher<T> graphSearcher;
|
||||||
|
|
||||||
final OnHeapHnswGraph hnsw;
|
final OnHeapHnswGraph hnsw;
|
||||||
|
|
||||||
@ -59,7 +64,18 @@ public final class HnswGraphBuilder {
|
|||||||
|
|
||||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||||
// colliding
|
// colliding
|
||||||
private RandomAccessVectorValues buildVectors;
|
private final RandomAccessVectorValues buildVectors;
|
||||||
|
|
||||||
|
public static HnswGraphBuilder<?> create(
|
||||||
|
RandomAccessVectorValuesProducer vectors,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
int M,
|
||||||
|
int beamWidth,
|
||||||
|
long seed)
|
||||||
|
throws IOException {
|
||||||
|
return new HnswGraphBuilder<>(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||||
@ -73,8 +89,9 @@ public final class HnswGraphBuilder {
|
|||||||
* @param seed the seed for a random number generator used during graph construction. Provide this
|
* @param seed the seed for a random number generator used during graph construction. Provide this
|
||||||
* to ensure repeatable construction.
|
* to ensure repeatable construction.
|
||||||
*/
|
*/
|
||||||
public HnswGraphBuilder(
|
private HnswGraphBuilder(
|
||||||
RandomAccessVectorValuesProducer vectors,
|
RandomAccessVectorValuesProducer vectors,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
int M,
|
int M,
|
||||||
int beamWidth,
|
int beamWidth,
|
||||||
@ -82,6 +99,7 @@ public final class HnswGraphBuilder {
|
|||||||
throws IOException {
|
throws IOException {
|
||||||
vectorValues = vectors.randomAccess();
|
vectorValues = vectors.randomAccess();
|
||||||
buildVectors = vectors.randomAccess();
|
buildVectors = vectors.randomAccess();
|
||||||
|
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
|
||||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||||
if (M <= 0) {
|
if (M <= 0) {
|
||||||
throw new IllegalArgumentException("maxConn must be positive");
|
throw new IllegalArgumentException("maxConn must be positive");
|
||||||
@ -97,7 +115,8 @@ public final class HnswGraphBuilder {
|
|||||||
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
int levelOfFirstNode = getRandomGraphLevel(ml, random);
|
||||||
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
|
this.hnsw = new OnHeapHnswGraph(M, levelOfFirstNode);
|
||||||
this.graphSearcher =
|
this.graphSearcher =
|
||||||
new HnswGraphSearcher(
|
new HnswGraphSearcher<>(
|
||||||
|
vectorEncoding,
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
new NeighborQueue(beamWidth, true),
|
new NeighborQueue(beamWidth, true),
|
||||||
new FixedBitSet(vectorValues.size()));
|
new FixedBitSet(vectorValues.size()));
|
||||||
@ -110,7 +129,7 @@ public final class HnswGraphBuilder {
|
|||||||
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||||
* returned values.
|
* returned values.
|
||||||
*
|
*
|
||||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||||
* accessor for the vectors
|
* accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||||
@ -121,15 +140,19 @@ public final class HnswGraphBuilder {
|
|||||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
|
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
|
||||||
}
|
}
|
||||||
|
addVectors(vectors);
|
||||||
|
return hnsw;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void addVectors(RandomAccessVectorValues vectors) throws IOException {
|
||||||
long start = System.nanoTime(), t = start;
|
long start = System.nanoTime(), t = start;
|
||||||
// start at node 1! node 0 is added implicitly, in the constructor
|
// start at node 1! node 0 is added implicitly, in the constructor
|
||||||
for (int node = 1; node < vectors.size(); node++) {
|
for (int node = 1; node < vectors.size(); node++) {
|
||||||
addGraphNode(node, vectors.vectorValue(node));
|
addGraphNode(node, vectors);
|
||||||
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
|
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||||
t = printGraphBuildStatus(node, start, t);
|
t = printGraphBuildStatus(node, start, t);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return hnsw;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Set info-stream to output debugging information * */
|
/** Set info-stream to output debugging information * */
|
||||||
@ -142,7 +165,7 @@ public final class HnswGraphBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Inserts a doc with vector value to the graph */
|
/** Inserts a doc with vector value to the graph */
|
||||||
public void addGraphNode(int node, float[] value) throws IOException {
|
public void addGraphNode(int node, T value) throws IOException {
|
||||||
NeighborQueue candidates;
|
NeighborQueue candidates;
|
||||||
final int nodeLevel = getRandomGraphLevel(ml, random);
|
final int nodeLevel = getRandomGraphLevel(ml, random);
|
||||||
int curMaxLevel = hnsw.numLevels() - 1;
|
int curMaxLevel = hnsw.numLevels() - 1;
|
||||||
@ -167,6 +190,18 @@ public final class HnswGraphBuilder {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
|
||||||
|
addGraphNode(node, getValue(node, values));
|
||||||
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
|
||||||
|
return switch (vectorEncoding) {
|
||||||
|
case BYTE -> (T) values.binaryValue(node);
|
||||||
|
case FLOAT32 -> (T) values.vectorValue(node);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
private long printGraphBuildStatus(int node, long start, long t) {
|
private long printGraphBuildStatus(int node, long start, long t) {
|
||||||
long now = System.nanoTime();
|
long now = System.nanoTime();
|
||||||
infoStream.message(
|
infoStream.message(
|
||||||
@ -215,7 +250,7 @@ public final class HnswGraphBuilder {
|
|||||||
int cNode = candidates.node[i];
|
int cNode = candidates.node[i];
|
||||||
float cScore = candidates.score[i];
|
float cScore = candidates.score[i];
|
||||||
assert cNode < hnsw.size();
|
assert cNode < hnsw.size();
|
||||||
if (diversityCheck(vectorValues.vectorValue(cNode), cScore, neighbors, buildVectors)) {
|
if (diversityCheck(cNode, cScore, neighbors)) {
|
||||||
neighbors.add(cNode, cScore);
|
neighbors.add(cNode, cScore);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -237,19 +272,38 @@ public final class HnswGraphBuilder {
|
|||||||
* @param score the score of the new candidate and node n, to be compared with scores of the
|
* @param score the score of the new candidate and node n, to be compared with scores of the
|
||||||
* candidate and n's neighbors
|
* candidate and n's neighbors
|
||||||
* @param neighbors the neighbors selected so far
|
* @param neighbors the neighbors selected so far
|
||||||
* @param vectorValues source of values used for making comparisons between candidate and existing
|
|
||||||
* neighbors
|
|
||||||
* @return whether the candidate is diverse given the existing neighbors
|
* @return whether the candidate is diverse given the existing neighbors
|
||||||
*/
|
*/
|
||||||
private boolean diversityCheck(
|
private boolean diversityCheck(int candidate, float score, NeighborArray neighbors)
|
||||||
float[] candidate,
|
throws IOException {
|
||||||
float score,
|
return isDiverse(candidate, neighbors, score);
|
||||||
NeighborArray neighbors,
|
}
|
||||||
RandomAccessVectorValues vectorValues)
|
|
||||||
|
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
|
||||||
|
throws IOException {
|
||||||
|
return switch (vectorEncoding) {
|
||||||
|
case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
|
||||||
|
case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isDiverse(float[] candidate, NeighborArray neighbors, float score)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
for (int i = 0; i < neighbors.size(); i++) {
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
float neighborSimilarity =
|
float neighborSimilarity =
|
||||||
similarityFunction.compare(candidate, vectorValues.vectorValue(neighbors.node[i]));
|
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
|
||||||
|
if (neighborSimilarity >= score) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isDiverse(BytesRef candidate, NeighborArray neighbors, float score)
|
||||||
|
throws IOException {
|
||||||
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
|
float neighborSimilarity =
|
||||||
|
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
|
||||||
if (neighborSimilarity >= score) {
|
if (neighborSimilarity >= score) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
@ -262,24 +316,52 @@ public final class HnswGraphBuilder {
|
|||||||
* neighbours
|
* neighbours
|
||||||
*/
|
*/
|
||||||
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
private int findWorstNonDiverse(NeighborArray neighbors) throws IOException {
|
||||||
float minAcceptedSimilarity;
|
|
||||||
for (int i = neighbors.size() - 1; i > 0; i--) {
|
for (int i = neighbors.size() - 1; i > 0; i--) {
|
||||||
int cNode = neighbors.node[i];
|
if (isWorstNonDiverse(i, neighbors, neighbors.score[i])) {
|
||||||
float[] cVector = vectorValues.vectorValue(cNode);
|
return i;
|
||||||
minAcceptedSimilarity = neighbors.score[i];
|
|
||||||
// check the candidate against its better-scoring neighbors
|
|
||||||
for (int j = i - 1; j >= 0; j--) {
|
|
||||||
float neighborSimilarity =
|
|
||||||
similarityFunction.compare(cVector, buildVectors.vectorValue(neighbors.node[j]));
|
|
||||||
// node i is too similar to node j given its score relative to the base node
|
|
||||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
|
||||||
return i;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return neighbors.size() - 1;
|
return neighbors.size() - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private boolean isWorstNonDiverse(
|
||||||
|
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
|
||||||
|
return switch (vectorEncoding) {
|
||||||
|
case BYTE -> isWorstNonDiverse(
|
||||||
|
candidate, vectorValues.binaryValue(candidate), neighbors, minAcceptedSimilarity);
|
||||||
|
case FLOAT32 -> isWorstNonDiverse(
|
||||||
|
candidate, vectorValues.vectorValue(candidate), neighbors, minAcceptedSimilarity);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isWorstNonDiverse(
|
||||||
|
int candidateIndex, float[] candidate, NeighborArray neighbors, float minAcceptedSimilarity)
|
||||||
|
throws IOException {
|
||||||
|
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||||
|
float neighborSimilarity =
|
||||||
|
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
|
||||||
|
// node i is too similar to node j given its score relative to the base node
|
||||||
|
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean isWorstNonDiverse(
|
||||||
|
int candidateIndex, BytesRef candidate, NeighborArray neighbors, float minAcceptedSimilarity)
|
||||||
|
throws IOException {
|
||||||
|
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||||
|
float neighborSimilarity =
|
||||||
|
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
|
||||||
|
// node i is too similar to node j given its score relative to the base node
|
||||||
|
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
private static int getRandomGraphLevel(double ml, SplittableRandom random) {
|
||||||
double randDouble;
|
double randDouble;
|
||||||
do {
|
do {
|
||||||
|
@ -18,21 +18,28 @@
|
|||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.util.BitSet;
|
import org.apache.lucene.util.BitSet;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.SparseFixedBitSet;
|
import org.apache.lucene.util.SparseFixedBitSet;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
|
* Searches an HNSW graph to find nearest neighbors to a query vector. For more background on the
|
||||||
* search algorithm, see {@link HnswGraph}.
|
* search algorithm, see {@link HnswGraph}.
|
||||||
|
*
|
||||||
|
* @param <T> the type of query vector
|
||||||
*/
|
*/
|
||||||
public final class HnswGraphSearcher {
|
public class HnswGraphSearcher<T> {
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
|
private final VectorEncoding vectorEncoding;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
|
* Scratch data structures that are used in each {@link #searchLevel} call. These can be expensive
|
||||||
* to allocate, so they're cleared and reused across calls.
|
* to allocate, so they're cleared and reused across calls.
|
||||||
@ -49,7 +56,11 @@ public final class HnswGraphSearcher {
|
|||||||
* @param visited bit set that will track nodes that have already been visited
|
* @param visited bit set that will track nodes that have already been visited
|
||||||
*/
|
*/
|
||||||
public HnswGraphSearcher(
|
public HnswGraphSearcher(
|
||||||
VectorSimilarityFunction similarityFunction, NeighborQueue candidates, BitSet visited) {
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
NeighborQueue candidates,
|
||||||
|
BitSet visited) {
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
this.similarityFunction = similarityFunction;
|
this.similarityFunction = similarityFunction;
|
||||||
this.candidates = candidates;
|
this.candidates = candidates;
|
||||||
this.visited = visited;
|
this.visited = visited;
|
||||||
@ -73,13 +84,68 @@ public final class HnswGraphSearcher {
|
|||||||
float[] query,
|
float[] query,
|
||||||
int topK,
|
int topK,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues vectors,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
HnswGraph graph,
|
HnswGraph graph,
|
||||||
Bits acceptOrds,
|
Bits acceptOrds,
|
||||||
int visitedLimit)
|
int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
HnswGraphSearcher graphSearcher =
|
if (query.length != vectors.dimension()) {
|
||||||
new HnswGraphSearcher(
|
throw new IllegalArgumentException(
|
||||||
|
"vector query dimension: "
|
||||||
|
+ query.length
|
||||||
|
+ " differs from field dimension: "
|
||||||
|
+ vectors.dimension());
|
||||||
|
}
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
return search(
|
||||||
|
toBytesRef(query),
|
||||||
|
topK,
|
||||||
|
vectors,
|
||||||
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
|
graph,
|
||||||
|
acceptOrds,
|
||||||
|
visitedLimit);
|
||||||
|
}
|
||||||
|
HnswGraphSearcher<float[]> graphSearcher =
|
||||||
|
new HnswGraphSearcher<>(
|
||||||
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
|
new NeighborQueue(topK, true),
|
||||||
|
new SparseFixedBitSet(vectors.size()));
|
||||||
|
NeighborQueue results;
|
||||||
|
int[] eps = new int[] {graph.entryNode()};
|
||||||
|
int numVisited = 0;
|
||||||
|
for (int level = graph.numLevels() - 1; level >= 1; level--) {
|
||||||
|
results = graphSearcher.searchLevel(query, 1, level, eps, vectors, graph, null, visitedLimit);
|
||||||
|
numVisited += results.visitedCount();
|
||||||
|
visitedLimit -= results.visitedCount();
|
||||||
|
if (results.incomplete()) {
|
||||||
|
results.setVisitedCount(numVisited);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
eps[0] = results.pop();
|
||||||
|
}
|
||||||
|
results =
|
||||||
|
graphSearcher.searchLevel(query, topK, 0, eps, vectors, graph, acceptOrds, visitedLimit);
|
||||||
|
results.setVisitedCount(results.visitedCount() + numVisited);
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static NeighborQueue search(
|
||||||
|
BytesRef query,
|
||||||
|
int topK,
|
||||||
|
RandomAccessVectorValues vectors,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction,
|
||||||
|
HnswGraph graph,
|
||||||
|
Bits acceptOrds,
|
||||||
|
int visitedLimit)
|
||||||
|
throws IOException {
|
||||||
|
HnswGraphSearcher<BytesRef> graphSearcher =
|
||||||
|
new HnswGraphSearcher<>(
|
||||||
|
vectorEncoding,
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
new NeighborQueue(topK, true),
|
new NeighborQueue(topK, true),
|
||||||
new SparseFixedBitSet(vectors.size()));
|
new SparseFixedBitSet(vectors.size()));
|
||||||
@ -119,7 +185,8 @@ public final class HnswGraphSearcher {
|
|||||||
* @return a priority queue holding the closest neighbors found
|
* @return a priority queue holding the closest neighbors found
|
||||||
*/
|
*/
|
||||||
public NeighborQueue searchLevel(
|
public NeighborQueue searchLevel(
|
||||||
float[] query,
|
// Note: this is only public because Lucene91HnswGraphBuilder needs it
|
||||||
|
T query,
|
||||||
int topK,
|
int topK,
|
||||||
int level,
|
int level,
|
||||||
final int[] eps,
|
final int[] eps,
|
||||||
@ -130,7 +197,7 @@ public final class HnswGraphSearcher {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private NeighborQueue searchLevel(
|
private NeighborQueue searchLevel(
|
||||||
float[] query,
|
T query,
|
||||||
int topK,
|
int topK,
|
||||||
int level,
|
int level,
|
||||||
final int[] eps,
|
final int[] eps,
|
||||||
@ -150,7 +217,7 @@ public final class HnswGraphSearcher {
|
|||||||
results.markIncomplete();
|
results.markIncomplete();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
float score = similarityFunction.compare(query, vectors.vectorValue(ep));
|
float score = compare(query, vectors, ep);
|
||||||
numVisited++;
|
numVisited++;
|
||||||
candidates.add(ep, score);
|
candidates.add(ep, score);
|
||||||
if (acceptOrds == null || acceptOrds.get(ep)) {
|
if (acceptOrds == null || acceptOrds.get(ep)) {
|
||||||
@ -185,7 +252,7 @@ public final class HnswGraphSearcher {
|
|||||||
results.markIncomplete();
|
results.markIncomplete();
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
float friendSimilarity = similarityFunction.compare(query, vectors.vectorValue(friendOrd));
|
float friendSimilarity = compare(query, vectors, friendOrd);
|
||||||
numVisited++;
|
numVisited++;
|
||||||
if (friendSimilarity >= minAcceptedSimilarity) {
|
if (friendSimilarity >= minAcceptedSimilarity) {
|
||||||
candidates.add(friendOrd, friendSimilarity);
|
candidates.add(friendOrd, friendSimilarity);
|
||||||
@ -204,6 +271,14 @@ public final class HnswGraphSearcher {
|
|||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
|
||||||
|
} else {
|
||||||
|
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private void prepareScratchState(int capacity) {
|
private void prepareScratchState(int capacity) {
|
||||||
candidates.clear();
|
candidates.clear();
|
||||||
if (visited.length() < capacity) {
|
if (visited.length() < capacity) {
|
||||||
|
@ -178,7 +178,7 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||||||
return new KnnVectorsWriter() {
|
return new KnnVectorsWriter() {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
fieldsWritten.add(fieldInfo.name);
|
fieldsWritten.add(fieldInfo.name);
|
||||||
return writer.addField(fieldInfo);
|
return writer.addField(fieldInfo);
|
||||||
}
|
}
|
||||||
|
@ -112,6 +112,7 @@ public class TestCodecs extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false));
|
false));
|
||||||
}
|
}
|
||||||
|
@ -260,6 +260,7 @@ public class TestFieldInfos extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false));
|
false));
|
||||||
}
|
}
|
||||||
@ -279,6 +280,7 @@ public class TestFieldInfos extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false));
|
false));
|
||||||
assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
|
assertEquals("Field numbers 0 through 9 were allocated", 10, idx);
|
||||||
@ -300,6 +302,7 @@ public class TestFieldInfos extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false));
|
false));
|
||||||
assertEquals("Field numbers should reset after clear()", 0, idx);
|
assertEquals("Field numbers should reset after clear()", 0, idx);
|
||||||
|
@ -64,6 +64,7 @@ public class TestFieldsReader extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
field.name().equals(softDeletesFieldName)));
|
field.name().equals(softDeletesFieldName)));
|
||||||
}
|
}
|
||||||
|
@ -113,6 +113,11 @@ public class TestIndexableField extends LuceneTestCase {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public VectorEncoding vectorEncoding() {
|
||||||
|
return VectorEncoding.FLOAT32;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public VectorSimilarityFunction vectorSimilarityFunction() {
|
public VectorSimilarityFunction vectorSimilarityFunction() {
|
||||||
return VectorSimilarityFunction.EUCLIDEAN;
|
return VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
@ -67,6 +67,8 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
|
private static int M = Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||||
|
|
||||||
private Codec codec;
|
private Codec codec;
|
||||||
|
private Codec float32Codec;
|
||||||
|
private VectorEncoding vectorEncoding;
|
||||||
private VectorSimilarityFunction similarityFunction;
|
private VectorSimilarityFunction similarityFunction;
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
@ -86,6 +88,31 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
|
|
||||||
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
|
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
|
||||||
similarityFunction = VectorSimilarityFunction.values()[similarity];
|
similarityFunction = VectorSimilarityFunction.values()[similarity];
|
||||||
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
|
||||||
|
codec =
|
||||||
|
new Lucene94Codec() {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
|
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if (vectorEncoding == VectorEncoding.FLOAT32) {
|
||||||
|
float32Codec = codec;
|
||||||
|
} else {
|
||||||
|
float32Codec =
|
||||||
|
new Lucene94Codec() {
|
||||||
|
@Override
|
||||||
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
|
return new Lucene94HnswVectorsFormat(M, Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorEncoding randomVectorEncoding() {
|
||||||
|
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||||
}
|
}
|
||||||
|
|
||||||
@After
|
@After
|
||||||
@ -102,10 +129,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
float[][] values = new float[numDoc][];
|
float[][] values = new float[numDoc][];
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
if (random().nextBoolean()) {
|
if (random().nextBoolean()) {
|
||||||
values[i] = new float[dimension];
|
values[i] = randomVector(dimension);
|
||||||
for (int j = 0; j < dimension; j++) {
|
|
||||||
values[i][j] = random().nextFloat();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
add(iw, i, values[i]);
|
add(iw, i, values[i]);
|
||||||
}
|
}
|
||||||
@ -117,6 +141,14 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
try (Directory dir = newDirectory();
|
try (Directory dir = newDirectory();
|
||||||
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
|
IndexWriter iw = new IndexWriter(dir, newIndexWriterConfig(null).setCodec(codec))) {
|
||||||
float[][] values = new float[][] {new float[] {0, 1, 2}};
|
float[][] values = new float[][] {new float[] {0, 1, 2}};
|
||||||
|
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
|
||||||
|
VectorUtil.l2normalize(values[0]);
|
||||||
|
}
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
for (int i = 0; i < 3; i++) {
|
||||||
|
values[0][i] = (float) Math.floor(values[0][i] * 127);
|
||||||
|
}
|
||||||
|
}
|
||||||
add(iw, 0, values[0]);
|
add(iw, 0, values[0]);
|
||||||
assertConsistentGraph(iw, values);
|
assertConsistentGraph(iw, values);
|
||||||
iw.commit();
|
iw.commit();
|
||||||
@ -133,11 +165,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
float[][] values = randomVectors(numDoc, dimension);
|
float[][] values = randomVectors(numDoc, dimension);
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
if (random().nextBoolean()) {
|
if (random().nextBoolean()) {
|
||||||
values[i] = new float[dimension];
|
values[i] = randomVector(dimension);
|
||||||
for (int j = 0; j < dimension; j++) {
|
|
||||||
values[i][j] = random().nextFloat();
|
|
||||||
}
|
|
||||||
VectorUtil.l2normalize(values[i]);
|
|
||||||
}
|
}
|
||||||
add(iw, i, values[i]);
|
add(iw, i, values[i]);
|
||||||
if (random().nextInt(10) == 3) {
|
if (random().nextInt(10) == 3) {
|
||||||
@ -249,16 +277,26 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
float[][] values = new float[numDoc][];
|
float[][] values = new float[numDoc][];
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
if (random().nextBoolean()) {
|
if (random().nextBoolean()) {
|
||||||
values[i] = new float[dimension];
|
values[i] = randomVector(dimension);
|
||||||
for (int j = 0; j < dimension; j++) {
|
|
||||||
values[i][j] = random().nextFloat();
|
|
||||||
}
|
|
||||||
VectorUtil.l2normalize(values[i]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private float[] randomVector(int dimension) {
|
||||||
|
float[] value = new float[dimension];
|
||||||
|
for (int j = 0; j < dimension; j++) {
|
||||||
|
value[j] = random().nextFloat();
|
||||||
|
}
|
||||||
|
VectorUtil.l2normalize(value);
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
for (int j = 0; j < dimension; j++) {
|
||||||
|
value[j] = (byte) (value[j] * 127);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return value;
|
||||||
|
}
|
||||||
|
|
||||||
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
|
int[][][] copyGraph(HnswGraph graphValues) throws IOException {
|
||||||
int[][][] graph = new int[graphValues.numLevels()][][];
|
int[][][] graph = new int[graphValues.numLevels()][][];
|
||||||
int size = graphValues.size();
|
int size = graphValues.size();
|
||||||
@ -285,7 +323,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
// We can't use dot product here since the vectors are laid out on a grid, not a sphere.
|
// We can't use dot product here since the vectors are laid out on a grid, not a sphere.
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
IndexWriterConfig config = newIndexWriterConfig();
|
IndexWriterConfig config = newIndexWriterConfig();
|
||||||
config.setCodec(codec); // test is not compatible with simpletext
|
config.setCodec(float32Codec);
|
||||||
try (Directory dir = newDirectory();
|
try (Directory dir = newDirectory();
|
||||||
IndexWriter iw = new IndexWriter(dir, config)) {
|
IndexWriter iw = new IndexWriter(dir, config)) {
|
||||||
indexData(iw);
|
indexData(iw);
|
||||||
@ -341,7 +379,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
public void testMultiThreadedSearch() throws Exception {
|
public void testMultiThreadedSearch() throws Exception {
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
IndexWriterConfig config = newIndexWriterConfig();
|
IndexWriterConfig config = newIndexWriterConfig();
|
||||||
config.setCodec(codec);
|
config.setCodec(float32Codec);
|
||||||
Directory dir = newDirectory();
|
Directory dir = newDirectory();
|
||||||
IndexWriter iw = new IndexWriter(dir, config);
|
IndexWriter iw = new IndexWriter(dir, config);
|
||||||
indexData(iw);
|
indexData(iw);
|
||||||
@ -468,7 +506,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||||||
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
|
"vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch),
|
||||||
values[id],
|
values[id],
|
||||||
scratch,
|
scratch,
|
||||||
0f);
|
0);
|
||||||
numDocsWithVectors++;
|
numDocsWithVectors++;
|
||||||
}
|
}
|
||||||
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()
|
// if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc()
|
||||||
|
@ -196,6 +196,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
|
List<Integer> docsDeleted = Arrays.asList(1, 3, 7, 8, DocIdSetIterator.NO_MORE_DOCS);
|
||||||
@ -233,6 +234,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
for (DocValuesFieldUpdates update : updates) {
|
for (DocValuesFieldUpdates update : updates) {
|
||||||
@ -295,6 +297,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
|
List<Integer> docsDeleted = Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS);
|
||||||
@ -362,6 +365,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
List<DocValuesFieldUpdates> updates =
|
List<DocValuesFieldUpdates> updates =
|
||||||
@ -398,6 +402,7 @@ public class TestPendingSoftDeletes extends TestPendingDeletes {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
true);
|
true);
|
||||||
updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));
|
updates = Arrays.asList(singleUpdate(Arrays.asList(1, DocIdSetIterator.NO_MORE_DOCS), 3, true));
|
||||||
|
@ -25,6 +25,7 @@ import java.util.concurrent.LinkedBlockingQueue;
|
|||||||
import java.util.concurrent.ThreadPoolExecutor;
|
import java.util.concurrent.ThreadPoolExecutor;
|
||||||
import java.util.concurrent.TimeUnit;
|
import java.util.concurrent.TimeUnit;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.MatchAllDocsQuery;
|
import org.apache.lucene.search.MatchAllDocsQuery;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
@ -117,6 +118,12 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void doClose() {}
|
protected void doClose() {}
|
||||||
|
|
||||||
|
@ -19,6 +19,7 @@ package org.apache.lucene.search;
|
|||||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
|
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
|
||||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||||
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
|
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
import static org.apache.lucene.util.TestVectorUtil.randomVector;
|
import static org.apache.lucene.util.TestVectorUtil.randomVector;
|
||||||
|
|
||||||
@ -40,7 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
|
|||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
@ -48,6 +49,7 @@ import org.apache.lucene.tests.util.TestUtil;
|
|||||||
import org.apache.lucene.util.BitSet;
|
import org.apache.lucene.util.BitSet;
|
||||||
import org.apache.lucene.util.BitSetIterator;
|
import org.apache.lucene.util.BitSetIterator;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
@ -174,7 +176,7 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||||||
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
|
KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
|
||||||
IllegalArgumentException e =
|
IllegalArgumentException e =
|
||||||
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
|
expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
|
||||||
assertEquals("vector dimensions differ: 1!=2", e.getMessage());
|
assertEquals("vector query dimension: 1 differs from field dimension: 2", e.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -239,43 +241,38 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void testScoreEuclidean() throws IOException {
|
public void testScoreEuclidean() throws IOException {
|
||||||
try (Directory d = newDirectory()) {
|
float[][] vectors = new float[5][];
|
||||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
for (int j = 0; j < 5; j++) {
|
||||||
for (int j = 0; j < 5; j++) {
|
vectors[j] = new float[] {j, j};
|
||||||
Document doc = new Document();
|
}
|
||||||
doc.add(
|
try (Directory d = getIndexStore("field", vectors);
|
||||||
new KnnVectorField("field", new float[] {j, j}, VectorSimilarityFunction.EUCLIDEAN));
|
IndexReader reader = DirectoryReader.open(d)) {
|
||||||
w.addDocument(doc);
|
assertEquals(1, reader.leaves().size());
|
||||||
}
|
IndexSearcher searcher = new IndexSearcher(reader);
|
||||||
}
|
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
|
||||||
try (IndexReader reader = DirectoryReader.open(d)) {
|
Query rewritten = query.rewrite(reader);
|
||||||
assertEquals(1, reader.leaves().size());
|
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
Scorer scorer = weight.scorer(reader.leaves().get(0));
|
||||||
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
|
|
||||||
Query rewritten = query.rewrite(reader);
|
|
||||||
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
|
|
||||||
Scorer scorer = weight.scorer(reader.leaves().get(0));
|
|
||||||
|
|
||||||
// prior to advancing, score is 0
|
// prior to advancing, score is 0
|
||||||
assertEquals(-1, scorer.docID());
|
assertEquals(-1, scorer.docID());
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
|
|
||||||
// test getMaxScore
|
// test getMaxScore
|
||||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
assertEquals(0, scorer.getMaxScore(-1), 0);
|
||||||
assertEquals(0, scorer.getMaxScore(0), 0);
|
assertEquals(0, scorer.getMaxScore(0), 0);
|
||||||
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
|
// This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
|
||||||
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
|
assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
|
||||||
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
||||||
|
|
||||||
DocIdSetIterator it = scorer.iterator();
|
DocIdSetIterator it = scorer.iterator();
|
||||||
assertEquals(3, it.cost());
|
assertEquals(3, it.cost());
|
||||||
assertEquals(1, it.nextDoc());
|
assertEquals(1, it.nextDoc());
|
||||||
assertEquals(1 / 6f, scorer.score(), 0);
|
assertEquals(1 / 6f, scorer.score(), 0);
|
||||||
assertEquals(3, it.advance(3));
|
assertEquals(3, it.advance(3));
|
||||||
assertEquals(1 / 2f, scorer.score(), 0);
|
assertEquals(1 / 2f, scorer.score(), 0);
|
||||||
assertEquals(NO_MORE_DOCS, it.advance(4));
|
assertEquals(NO_MORE_DOCS, it.advance(4));
|
||||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -764,9 +761,18 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||||||
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
private Directory getIndexStore(String field, float[]... contents) throws IOException {
|
||||||
Directory indexStore = newDirectory();
|
Directory indexStore = newDirectory();
|
||||||
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
|
||||||
|
VectorEncoding encoding = randomVectorEncoding();
|
||||||
for (int i = 0; i < contents.length; ++i) {
|
for (int i = 0; i < contents.length; ++i) {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(new KnnVectorField(field, contents[i]));
|
if (encoding == VectorEncoding.BYTE) {
|
||||||
|
BytesRef v = new BytesRef(new byte[contents[i].length]);
|
||||||
|
for (int j = 0; j < v.length; j++) {
|
||||||
|
v.bytes[j] = (byte) contents[i][j];
|
||||||
|
}
|
||||||
|
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
|
||||||
|
} else {
|
||||||
|
doc.add(new KnnVectorField(field, contents[i]));
|
||||||
|
}
|
||||||
doc.add(new StringField("id", "id" + i, Field.Store.YES));
|
doc.add(new StringField("id", "id" + i, Field.Store.YES));
|
||||||
writer.addDocument(doc);
|
writer.addDocument(doc);
|
||||||
}
|
}
|
||||||
@ -908,4 +914,8 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||||||
return 31 * classHash() + docs.hashCode();
|
return 31 * classHash() + docs.hashCode();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private VectorEncoding randomVectorEncoding() {
|
||||||
|
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -50,6 +50,7 @@ import org.apache.lucene.index.LeafReader;
|
|||||||
import org.apache.lucene.index.PointValues;
|
import org.apache.lucene.index.PointValues;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
import org.apache.lucene.index.Terms;
|
import org.apache.lucene.index.Terms;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.search.SortField.Type;
|
import org.apache.lucene.search.SortField.Type;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
@ -1126,6 +1127,7 @@ public class TestSortOptimization extends LuceneTestCase {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.DOT_PRODUCT,
|
VectorSimilarityFunction.DOT_PRODUCT,
|
||||||
fi.isSoftDeletesField());
|
fi.isSoftDeletesField());
|
||||||
newInfos[i] = noIndexFI;
|
newInfos[i] = noIndexFI;
|
||||||
|
@ -18,6 +18,7 @@ package org.apache.lucene.util;
|
|||||||
|
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
|
|
||||||
public class TestVectorUtil extends LuceneTestCase {
|
public class TestVectorUtil extends LuceneTestCase {
|
||||||
|
|
||||||
@ -130,6 +131,23 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||||||
return u;
|
return u;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static BytesRef negative(BytesRef v) {
|
||||||
|
BytesRef u = new BytesRef(new byte[v.length]);
|
||||||
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
// what is (byte) -(-128)? 127?
|
||||||
|
u.bytes[i] = (byte) -v.bytes[i];
|
||||||
|
}
|
||||||
|
return u;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static float l2(BytesRef v) {
|
||||||
|
float l2 = 0;
|
||||||
|
for (int i = v.offset; i < v.offset + v.length; i++) {
|
||||||
|
l2 += v.bytes[i] * v.bytes[i];
|
||||||
|
}
|
||||||
|
return l2;
|
||||||
|
}
|
||||||
|
|
||||||
private static float[] randomVector() {
|
private static float[] randomVector() {
|
||||||
return randomVector(random().nextInt(100) + 1);
|
return randomVector(random().nextInt(100) + 1);
|
||||||
}
|
}
|
||||||
@ -142,4 +160,88 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static BytesRef randomVectorBytes() {
|
||||||
|
BytesRef v = TestUtil.randomBinaryTerm(random(), TestUtil.nextInt(random(), 1, 100));
|
||||||
|
// clip at -127 to avoid overflow
|
||||||
|
for (int i = v.offset; i < v.offset + v.length; i++) {
|
||||||
|
if (v.bytes[i] == -128) {
|
||||||
|
v.bytes[i] = -127;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testBasicDotProductBytes() {
|
||||||
|
BytesRef a = new BytesRef(new byte[] {1, 2, 3});
|
||||||
|
BytesRef b = new BytesRef(new byte[] {-10, 0, 5});
|
||||||
|
assertEquals(5, VectorUtil.dotProduct(a, b), 0);
|
||||||
|
assertEquals(5 / (3f * (1 << 15)), VectorUtil.dotProductScore(a, b), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSelfDotProductBytes() {
|
||||||
|
// the dot product of a vector with itself is equal to the sum of the squares of its components
|
||||||
|
BytesRef v = randomVectorBytes();
|
||||||
|
assertEquals(l2(v), VectorUtil.dotProduct(v, v), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOrthogonalDotProductBytes() {
|
||||||
|
// the dot product of two perpendicular vectors is 0
|
||||||
|
byte[] v = new byte[4];
|
||||||
|
v[0] = (byte) random().nextInt(100);
|
||||||
|
v[1] = (byte) random().nextInt(100);
|
||||||
|
v[2] = v[1];
|
||||||
|
v[3] = (byte) -v[0];
|
||||||
|
// also test computing using BytesRef with nonzero offset
|
||||||
|
assertEquals(0, VectorUtil.dotProduct(new BytesRef(v, 0, 2), new BytesRef(v, 2, 2)), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSelfSquareDistanceBytes() {
|
||||||
|
// the l2 distance of a vector with itself is zero
|
||||||
|
BytesRef v = randomVectorBytes();
|
||||||
|
assertEquals(0, VectorUtil.squareDistance(v, v), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testBasicSquareDistanceBytes() {
|
||||||
|
assertEquals(
|
||||||
|
12,
|
||||||
|
VectorUtil.squareDistance(
|
||||||
|
new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-1, 0, 5})),
|
||||||
|
0);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testRandomSquareDistanceBytes() {
|
||||||
|
// the square distance of a vector with its inverse is equal to four times the sum of squares of
|
||||||
|
// its components
|
||||||
|
BytesRef v = randomVectorBytes();
|
||||||
|
BytesRef u = negative(v);
|
||||||
|
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testBasicCosineBytes() {
|
||||||
|
assertEquals(
|
||||||
|
0.11952f,
|
||||||
|
VectorUtil.cosine(new BytesRef(new byte[] {1, 2, 3}), new BytesRef(new byte[] {-10, 0, 5})),
|
||||||
|
DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSelfCosineBytes() {
|
||||||
|
// the dot product of a vector with itself is always equal to 1
|
||||||
|
BytesRef v = randomVectorBytes();
|
||||||
|
// ensure the vector is non-zero so that cosine is defined
|
||||||
|
v.bytes[0] = (byte) (random().nextInt(126) + 1);
|
||||||
|
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testOrthogonalCosineBytes() {
|
||||||
|
// the cosine of two perpendicular vectors is 0
|
||||||
|
float[] v = new float[2];
|
||||||
|
v[0] = random().nextInt(100);
|
||||||
|
// ensure the vector is non-zero so that cosine is defined
|
||||||
|
v[1] = random().nextInt(1, 100);
|
||||||
|
float[] u = new float[2];
|
||||||
|
u[0] = v[1];
|
||||||
|
u[1] = -v[0];
|
||||||
|
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -56,6 +56,7 @@ import org.apache.lucene.index.LeafReader;
|
|||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.search.ConstantScoreScorer;
|
import org.apache.lucene.search.ConstantScoreScorer;
|
||||||
import org.apache.lucene.search.ConstantScoreWeight;
|
import org.apache.lucene.search.ConstantScoreWeight;
|
||||||
@ -101,6 +102,7 @@ public class KnnGraphTester {
|
|||||||
private int beamWidth;
|
private int beamWidth;
|
||||||
private int maxConn;
|
private int maxConn;
|
||||||
private VectorSimilarityFunction similarityFunction;
|
private VectorSimilarityFunction similarityFunction;
|
||||||
|
private VectorEncoding vectorEncoding;
|
||||||
private FixedBitSet matchDocs;
|
private FixedBitSet matchDocs;
|
||||||
private float selectivity;
|
private float selectivity;
|
||||||
private boolean prefilter;
|
private boolean prefilter;
|
||||||
@ -113,6 +115,7 @@ public class KnnGraphTester {
|
|||||||
topK = 100;
|
topK = 100;
|
||||||
fanout = topK;
|
fanout = topK;
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
|
vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
selectivity = 1f;
|
selectivity = 1f;
|
||||||
prefilter = false;
|
prefilter = false;
|
||||||
}
|
}
|
||||||
@ -195,12 +198,30 @@ public class KnnGraphTester {
|
|||||||
case "-docs":
|
case "-docs":
|
||||||
docVectorsPath = Paths.get(args[++iarg]);
|
docVectorsPath = Paths.get(args[++iarg]);
|
||||||
break;
|
break;
|
||||||
|
case "-encoding":
|
||||||
|
String encoding = args[++iarg];
|
||||||
|
switch (encoding) {
|
||||||
|
case "byte":
|
||||||
|
vectorEncoding = VectorEncoding.BYTE;
|
||||||
|
break;
|
||||||
|
case "float32":
|
||||||
|
vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("-encoding can be 'byte' or 'float32' only");
|
||||||
|
}
|
||||||
|
break;
|
||||||
case "-metric":
|
case "-metric":
|
||||||
String metric = args[++iarg];
|
String metric = args[++iarg];
|
||||||
if (metric.equals("euclidean")) {
|
switch (metric) {
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
case "euclidean":
|
||||||
} else if (metric.equals("angular") == false) {
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
|
break;
|
||||||
|
case "angular":
|
||||||
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("-metric can be 'angular' or 'euclidean' only");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
case "-forceMerge":
|
case "-forceMerge":
|
||||||
@ -229,7 +250,7 @@ public class KnnGraphTester {
|
|||||||
if (operation == null && reindex == false) {
|
if (operation == null && reindex == false) {
|
||||||
usage();
|
usage();
|
||||||
}
|
}
|
||||||
if (prefilter == true && selectivity == 1f) {
|
if (prefilter && selectivity == 1f) {
|
||||||
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
|
throw new IllegalArgumentException("-prefilter requires filterSelectivity between 0 and 1");
|
||||||
}
|
}
|
||||||
indexPath = Paths.get(formatIndexPath(docVectorsPath));
|
indexPath = Paths.get(formatIndexPath(docVectorsPath));
|
||||||
@ -248,7 +269,9 @@ public class KnnGraphTester {
|
|||||||
if (docVectorsPath == null) {
|
if (docVectorsPath == null) {
|
||||||
throw new IllegalArgumentException("missing -docs arg");
|
throw new IllegalArgumentException("missing -docs arg");
|
||||||
}
|
}
|
||||||
matchDocs = generateRandomBitSet(numDocs, selectivity);
|
if (selectivity < 1) {
|
||||||
|
matchDocs = generateRandomBitSet(numDocs, selectivity);
|
||||||
|
}
|
||||||
if (outputPath != null) {
|
if (outputPath != null) {
|
||||||
testSearch(indexPath, queryPath, outputPath, null);
|
testSearch(indexPath, queryPath, outputPath, null);
|
||||||
} else {
|
} else {
|
||||||
@ -285,14 +308,17 @@ public class KnnGraphTester {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
private void dumpGraph(Path docsPath) throws IOException {
|
private void dumpGraph(Path docsPath) throws IOException {
|
||||||
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
|
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
|
||||||
RandomAccessVectorValues values = vectors.randomAccess();
|
RandomAccessVectorValues values = vectors.randomAccess();
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<float[]> builder =
|
||||||
new HnswGraphBuilder(vectors, similarityFunction, maxConn, beamWidth, 0);
|
(HnswGraphBuilder<float[]>)
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
|
||||||
// start at node 1
|
// start at node 1
|
||||||
for (int i = 1; i < numDocs; i++) {
|
for (int i = 1; i < numDocs; i++) {
|
||||||
builder.addGraphNode(i, values.vectorValue(i));
|
builder.addGraphNode(i, values);
|
||||||
System.out.println("\nITERATION " + i);
|
System.out.println("\nITERATION " + i);
|
||||||
dumpGraph(builder.hnsw);
|
dumpGraph(builder.hnsw);
|
||||||
}
|
}
|
||||||
@ -375,13 +401,8 @@ public class KnnGraphTester {
|
|||||||
throws IOException {
|
throws IOException {
|
||||||
TopDocs[] results = new TopDocs[numIters];
|
TopDocs[] results = new TopDocs[numIters];
|
||||||
long elapsed, totalCpuTime, totalVisited = 0;
|
long elapsed, totalCpuTime, totalVisited = 0;
|
||||||
try (FileChannel q = FileChannel.open(queryPath)) {
|
try (FileChannel input = FileChannel.open(queryPath)) {
|
||||||
int bufferSize = numIters * dim * Float.BYTES;
|
VectorReader targetReader = VectorReader.create(input, dim, vectorEncoding, numIters);
|
||||||
FloatBuffer targets =
|
|
||||||
q.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize)
|
|
||||||
.order(ByteOrder.LITTLE_ENDIAN)
|
|
||||||
.asFloatBuffer();
|
|
||||||
float[] target = new float[dim];
|
|
||||||
if (quiet == false) {
|
if (quiet == false) {
|
||||||
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
|
System.out.println("running " + numIters + " targets; topK=" + topK + ", fanout=" + fanout);
|
||||||
}
|
}
|
||||||
@ -392,21 +413,21 @@ public class KnnGraphTester {
|
|||||||
DirectoryReader reader = DirectoryReader.open(dir)) {
|
DirectoryReader reader = DirectoryReader.open(dir)) {
|
||||||
IndexSearcher searcher = new IndexSearcher(reader);
|
IndexSearcher searcher = new IndexSearcher(reader);
|
||||||
numDocs = reader.maxDoc();
|
numDocs = reader.maxDoc();
|
||||||
Query bitSetQuery = new BitSetQuery(matchDocs);
|
Query bitSetQuery = prefilter ? new BitSetQuery(matchDocs) : null;
|
||||||
for (int i = 0; i < numIters; i++) {
|
for (int i = 0; i < numIters; i++) {
|
||||||
// warm up
|
// warm up
|
||||||
targets.get(target);
|
float[] target = targetReader.next();
|
||||||
if (prefilter) {
|
if (prefilter) {
|
||||||
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||||
} else {
|
} else {
|
||||||
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
doKnnVectorQuery(searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
targets.position(0);
|
targetReader.reset();
|
||||||
start = System.nanoTime();
|
start = System.nanoTime();
|
||||||
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
|
cpuTimeStartNs = bean.getCurrentThreadCpuTime();
|
||||||
for (int i = 0; i < numIters; i++) {
|
for (int i = 0; i < numIters; i++) {
|
||||||
targets.get(target);
|
float[] target = targetReader.next();
|
||||||
if (prefilter) {
|
if (prefilter) {
|
||||||
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, bitSetQuery);
|
||||||
} else {
|
} else {
|
||||||
@ -414,10 +435,12 @@ public class KnnGraphTester {
|
|||||||
doKnnVectorQuery(
|
doKnnVectorQuery(
|
||||||
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
searcher, KNN_FIELD, target, (int) (topK / selectivity), fanout, null);
|
||||||
|
|
||||||
results[i].scoreDocs =
|
if (matchDocs != null) {
|
||||||
Arrays.stream(results[i].scoreDocs)
|
results[i].scoreDocs =
|
||||||
.filter(scoreDoc -> matchDocs == null || matchDocs.get(scoreDoc.doc))
|
Arrays.stream(results[i].scoreDocs)
|
||||||
.toArray(ScoreDoc[]::new);
|
.filter(scoreDoc -> matchDocs.get(scoreDoc.doc))
|
||||||
|
.toArray(ScoreDoc[]::new);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
|
totalCpuTime = (bean.getCurrentThreadCpuTime() - cpuTimeStartNs) / 1_000_000;
|
||||||
@ -425,7 +448,14 @@ public class KnnGraphTester {
|
|||||||
for (int i = 0; i < numIters; i++) {
|
for (int i = 0; i < numIters; i++) {
|
||||||
totalVisited += results[i].totalHits.value;
|
totalVisited += results[i].totalHits.value;
|
||||||
for (ScoreDoc doc : results[i].scoreDocs) {
|
for (ScoreDoc doc : results[i].scoreDocs) {
|
||||||
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
|
if (doc.doc != NO_MORE_DOCS) {
|
||||||
|
// there is a bug somewhere that can result in doc=NO_MORE_DOCS! I think it happens
|
||||||
|
// in some degenerate case (like input query has NaN in it?) that causes no results to
|
||||||
|
// be returned from HNSW search?
|
||||||
|
doc.doc = Integer.parseInt(reader.document(doc.doc).get("id"));
|
||||||
|
} else {
|
||||||
|
System.out.println("NO_MORE_DOCS!");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -477,6 +507,78 @@ public class KnnGraphTester {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private abstract static class VectorReader {
|
||||||
|
final float[] target;
|
||||||
|
final ByteBuffer bytes;
|
||||||
|
|
||||||
|
static VectorReader create(FileChannel input, int dim, VectorEncoding vectorEncoding, int n)
|
||||||
|
throws IOException {
|
||||||
|
int bufferSize = n * dim * vectorEncoding.byteSize;
|
||||||
|
return switch (vectorEncoding) {
|
||||||
|
case BYTE -> new VectorReaderByte(input, dim, bufferSize);
|
||||||
|
case FLOAT32 -> new VectorReaderFloat32(input, dim, bufferSize);
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
VectorReader(FileChannel input, int dim, int bufferSize) throws IOException {
|
||||||
|
bytes =
|
||||||
|
input.map(FileChannel.MapMode.READ_ONLY, 0, bufferSize).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
target = new float[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
bytes.position(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract float[] next();
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class VectorReaderFloat32 extends VectorReader {
|
||||||
|
private final FloatBuffer floats;
|
||||||
|
|
||||||
|
VectorReaderFloat32(FileChannel input, int dim, int bufferSize) throws IOException {
|
||||||
|
super(input, dim, bufferSize);
|
||||||
|
floats = bytes.asFloatBuffer();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
void reset() {
|
||||||
|
super.reset();
|
||||||
|
floats.position(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
float[] next() {
|
||||||
|
floats.get(target);
|
||||||
|
return target;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class VectorReaderByte extends VectorReader {
|
||||||
|
private byte[] scratch;
|
||||||
|
private BytesRef bytesRef;
|
||||||
|
|
||||||
|
VectorReaderByte(FileChannel input, int dim, int bufferSize) throws IOException {
|
||||||
|
super(input, dim, bufferSize);
|
||||||
|
scratch = new byte[dim];
|
||||||
|
bytesRef = new BytesRef(scratch);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
float[] next() {
|
||||||
|
bytes.get(scratch);
|
||||||
|
for (int i = 0; i < scratch.length; i++) {
|
||||||
|
target[i] = scratch[i];
|
||||||
|
}
|
||||||
|
return target;
|
||||||
|
}
|
||||||
|
|
||||||
|
BytesRef nextBytes() {
|
||||||
|
bytes.get(scratch);
|
||||||
|
return bytesRef;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private static TopDocs doKnnVectorQuery(
|
private static TopDocs doKnnVectorQuery(
|
||||||
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
|
IndexSearcher searcher, String field, float[] vector, int k, int fanout, Query filter)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
@ -529,7 +631,9 @@ public class KnnGraphTester {
|
|||||||
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
|
if (Files.exists(nnPath) && isNewer(nnPath, docPath, queryPath) && selectivity == 1f) {
|
||||||
return readNN(nnPath);
|
return readNN(nnPath);
|
||||||
} else {
|
} else {
|
||||||
int[][] nn = computeNN(docPath, queryPath);
|
// TODO: enable computing NN from high precision vectors when
|
||||||
|
// checking low-precision recall
|
||||||
|
int[][] nn = computeNN(docPath, queryPath, vectorEncoding);
|
||||||
if (selectivity == 1f) {
|
if (selectivity == 1f) {
|
||||||
writeNN(nn, nnPath);
|
writeNN(nn, nnPath);
|
||||||
}
|
}
|
||||||
@ -589,52 +693,37 @@ public class KnnGraphTester {
|
|||||||
return bitSet;
|
return bitSet;
|
||||||
}
|
}
|
||||||
|
|
||||||
private int[][] computeNN(Path docPath, Path queryPath) throws IOException {
|
private int[][] computeNN(Path docPath, Path queryPath, VectorEncoding encoding)
|
||||||
|
throws IOException {
|
||||||
int[][] result = new int[numIters][];
|
int[][] result = new int[numIters][];
|
||||||
if (quiet == false) {
|
if (quiet == false) {
|
||||||
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
|
System.out.println("computing true nearest neighbors of " + numIters + " target vectors");
|
||||||
}
|
}
|
||||||
try (FileChannel in = FileChannel.open(docPath);
|
try (FileChannel in = FileChannel.open(docPath);
|
||||||
FileChannel qIn = FileChannel.open(queryPath)) {
|
FileChannel qIn = FileChannel.open(queryPath)) {
|
||||||
FloatBuffer queries =
|
VectorReader docReader = VectorReader.create(in, dim, encoding, numDocs);
|
||||||
qIn.map(FileChannel.MapMode.READ_ONLY, 0, numIters * dim * Float.BYTES)
|
VectorReader queryReader = VectorReader.create(qIn, dim, encoding, numIters);
|
||||||
.order(ByteOrder.LITTLE_ENDIAN)
|
|
||||||
.asFloatBuffer();
|
|
||||||
float[] vector = new float[dim];
|
|
||||||
float[] query = new float[dim];
|
|
||||||
for (int i = 0; i < numIters; i++) {
|
for (int i = 0; i < numIters; i++) {
|
||||||
queries.get(query);
|
float[] query = queryReader.next();
|
||||||
long totalBytes = (long) numDocs * dim * Float.BYTES;
|
NeighborQueue queue = new NeighborQueue(topK, false);
|
||||||
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
|
for (int j = 0; j < numDocs; j++) {
|
||||||
int offset = 0;
|
float[] doc = docReader.next();
|
||||||
int j = 0;
|
float d = similarityFunction.compare(query, doc);
|
||||||
// System.out.println("totalBytes=" + totalBytes);
|
if (matchDocs == null || matchDocs.get(j)) {
|
||||||
while (j < numDocs) {
|
queue.insertWithOverflow(j, d);
|
||||||
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
|
|
||||||
FloatBuffer vectors =
|
|
||||||
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
|
|
||||||
.order(ByteOrder.LITTLE_ENDIAN)
|
|
||||||
.asFloatBuffer();
|
|
||||||
offset += blockSize;
|
|
||||||
NeighborQueue queue = new NeighborQueue(topK, false);
|
|
||||||
for (; j < numDocs && vectors.hasRemaining(); j++) {
|
|
||||||
vectors.get(vector);
|
|
||||||
float d = similarityFunction.compare(query, vector);
|
|
||||||
if (matchDocs == null || matchDocs.get(j)) {
|
|
||||||
queue.insertWithOverflow(j, d);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result[i] = new int[topK];
|
|
||||||
for (int k = topK - 1; k >= 0; k--) {
|
|
||||||
result[i][k] = queue.topNode();
|
|
||||||
queue.pop();
|
|
||||||
// System.out.print(" " + n);
|
|
||||||
}
|
|
||||||
if (quiet == false && (i + 1) % 10 == 0) {
|
|
||||||
System.out.print(" " + (i + 1));
|
|
||||||
System.out.flush();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
docReader.reset();
|
||||||
|
result[i] = new int[topK];
|
||||||
|
for (int k = topK - 1; k >= 0; k--) {
|
||||||
|
result[i][k] = queue.topNode();
|
||||||
|
queue.pop();
|
||||||
|
// System.out.print(" " + n);
|
||||||
|
}
|
||||||
|
if (quiet == false && (i + 1) % 10 == 0) {
|
||||||
|
System.out.print(" " + (i + 1));
|
||||||
|
System.out.flush();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
@ -651,37 +740,29 @@ public class KnnGraphTester {
|
|||||||
});
|
});
|
||||||
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
|
// iwc.setMergePolicy(NoMergePolicy.INSTANCE);
|
||||||
iwc.setRAMBufferSizeMB(1994d);
|
iwc.setRAMBufferSizeMB(1994d);
|
||||||
|
iwc.setUseCompoundFile(false);
|
||||||
// iwc.setMaxBufferedDocs(10000);
|
// iwc.setMaxBufferedDocs(10000);
|
||||||
|
|
||||||
FieldType fieldType = KnnVectorField.createFieldType(dim, similarityFunction);
|
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
|
||||||
if (quiet == false) {
|
if (quiet == false) {
|
||||||
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
||||||
System.out.println("creating index in " + indexPath);
|
System.out.println("creating index in " + indexPath);
|
||||||
}
|
}
|
||||||
long start = System.nanoTime();
|
long start = System.nanoTime();
|
||||||
long totalBytes = (long) numDocs * dim * Float.BYTES, offset = 0;
|
|
||||||
final int maxBlockSize = (Integer.MAX_VALUE / (dim * Float.BYTES)) * (dim * Float.BYTES);
|
|
||||||
|
|
||||||
try (FSDirectory dir = FSDirectory.open(indexPath);
|
try (FSDirectory dir = FSDirectory.open(indexPath);
|
||||||
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
IndexWriter iw = new IndexWriter(dir, iwc)) {
|
||||||
float[] vector = new float[dim];
|
|
||||||
try (FileChannel in = FileChannel.open(docsPath)) {
|
try (FileChannel in = FileChannel.open(docsPath)) {
|
||||||
int i = 0;
|
VectorReader vectorReader = VectorReader.create(in, dim, vectorEncoding, numDocs);
|
||||||
while (i < numDocs) {
|
for (int i = 0; i < numDocs; i++) {
|
||||||
int blockSize = (int) Math.min(totalBytes - offset, maxBlockSize);
|
Document doc = new Document();
|
||||||
FloatBuffer vectors =
|
switch (vectorEncoding) {
|
||||||
in.map(FileChannel.MapMode.READ_ONLY, offset, blockSize)
|
case BYTE -> doc.add(
|
||||||
.order(ByteOrder.LITTLE_ENDIAN)
|
new KnnVectorField(
|
||||||
.asFloatBuffer();
|
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
|
||||||
offset += blockSize;
|
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
|
||||||
for (; vectors.hasRemaining() && i < numDocs; i++) {
|
|
||||||
vectors.get(vector);
|
|
||||||
Document doc = new Document();
|
|
||||||
// System.out.println("vector=" + vector[0] + "," + vector[1] + "...");
|
|
||||||
doc.add(new KnnVectorField(KNN_FIELD, vector, fieldType));
|
|
||||||
doc.add(new StoredField(ID_FIELD, i));
|
|
||||||
iw.addDocument(doc);
|
|
||||||
}
|
}
|
||||||
|
doc.add(new StoredField(ID_FIELD, i));
|
||||||
|
iw.addDocument(doc);
|
||||||
}
|
}
|
||||||
if (quiet == false) {
|
if (quiet == false) {
|
||||||
System.out.println("Done indexing " + numDocs + " documents; now flush");
|
System.out.println("Done indexing " + numDocs + " documents; now flush");
|
||||||
|
@ -31,6 +31,7 @@ class MockVectorValues extends VectorValues
|
|||||||
protected final float[][] denseValues;
|
protected final float[][] denseValues;
|
||||||
protected final float[][] values;
|
protected final float[][] values;
|
||||||
private final int numVectors;
|
private final int numVectors;
|
||||||
|
private final BytesRef binaryValue;
|
||||||
|
|
||||||
private int pos = -1;
|
private int pos = -1;
|
||||||
|
|
||||||
@ -47,6 +48,9 @@ class MockVectorValues extends VectorValues
|
|||||||
}
|
}
|
||||||
numVectors = count;
|
numVectors = count;
|
||||||
scratch = new float[dimension];
|
scratch = new float[dimension];
|
||||||
|
// used by tests that build a graph from bytes rather than floats
|
||||||
|
binaryValue = new BytesRef(dimension);
|
||||||
|
binaryValue.length = dimension;
|
||||||
}
|
}
|
||||||
|
|
||||||
public MockVectorValues copy() {
|
public MockVectorValues copy() {
|
||||||
@ -89,7 +93,11 @@ class MockVectorValues extends VectorValues
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int targetOrd) {
|
public BytesRef binaryValue(int targetOrd) {
|
||||||
return null;
|
float[] value = vectorValue(targetOrd);
|
||||||
|
for (int i = 0; i < value.length; i++) {
|
||||||
|
binaryValue.bytes[i] = (byte) value[i];
|
||||||
|
}
|
||||||
|
return binaryValue;
|
||||||
}
|
}
|
||||||
|
|
||||||
private boolean seek(int target) {
|
private boolean seek(int target) {
|
||||||
|
@ -18,6 +18,7 @@
|
|||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@ -43,6 +44,7 @@ import org.apache.lucene.index.IndexWriterConfig;
|
|||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
@ -60,25 +62,38 @@ import org.apache.lucene.util.BytesRef;
|
|||||||
import org.apache.lucene.util.FixedBitSet;
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
/** Tests HNSW KNN graphs */
|
/** Tests HNSW KNN graphs */
|
||||||
public class TestHnswGraph extends LuceneTestCase {
|
public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
|
||||||
|
VectorSimilarityFunction similarityFunction;
|
||||||
|
VectorEncoding vectorEncoding;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
similarityFunction =
|
||||||
|
VectorSimilarityFunction.values()[
|
||||||
|
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
||||||
|
if (similarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
|
||||||
|
vectorEncoding =
|
||||||
|
VectorEncoding.values()[random().nextInt(VectorEncoding.values().length - 1) + 1];
|
||||||
|
} else {
|
||||||
|
vectorEncoding = VectorEncoding.FLOAT32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// test writing out and reading in a graph gives the expected graph
|
// test writing out and reading in a graph gives the expected graph
|
||||||
public void testReadWrite() throws IOException {
|
public void testReadWrite() throws IOException {
|
||||||
int dim = random().nextInt(100) + 1;
|
int dim = random().nextInt(100) + 1;
|
||||||
int nDoc = random().nextInt(100) + 1;
|
int nDoc = random().nextInt(100) + 1;
|
||||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
|
int M = random().nextInt(4) + 2;
|
||||||
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
|
||||||
|
|
||||||
int M = random().nextInt(10) + 5;
|
|
||||||
int beamWidth = random().nextInt(10) + 5;
|
int beamWidth = random().nextInt(10) + 5;
|
||||||
long seed = random().nextLong();
|
long seed = random().nextLong();
|
||||||
VectorSimilarityFunction similarityFunction =
|
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
|
||||||
VectorSimilarityFunction.values()[
|
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
||||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
HnswGraphBuilder<?> builder =
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
|
||||||
new HnswGraphBuilder(vectors, similarityFunction, M, beamWidth, seed);
|
|
||||||
HnswGraph hnsw = builder.build(vectors);
|
HnswGraph hnsw = builder.build(vectors);
|
||||||
|
|
||||||
// Recreate the graph while indexing with the same random seed and write it out
|
// Recreate the graph while indexing with the same random seed and write it out
|
||||||
@ -131,6 +146,10 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private VectorEncoding randomVectorEncoding() {
|
||||||
|
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||||
|
}
|
||||||
|
|
||||||
// test that sorted index returns the same search results are unsorted
|
// test that sorted index returns the same search results are unsorted
|
||||||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||||
int dim = random().nextInt(10) + 3;
|
int dim = random().nextInt(10) + 3;
|
||||||
@ -250,24 +269,27 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
// oriented in the right directions
|
// oriented in the right directions
|
||||||
public void testAknnDiverse() throws IOException {
|
public void testAknnDiverse() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<?> builder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 10, 100, random().nextInt());
|
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
// run some searches
|
// run some searches
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
new float[] {1, 0},
|
getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.randomAccess(),
|
vectors.randomAccess(),
|
||||||
VectorSimilarityFunction.DOT_PRODUCT,
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
null,
|
null,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
|
|
||||||
int[] nodes = nn.nodes();
|
int[] nodes = nn.nodes();
|
||||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||||
int sum = 0;
|
int sum = 0;
|
||||||
for (int node : nodes) {
|
for (int node : nodes) {
|
||||||
sum += node;
|
sum += node;
|
||||||
@ -289,23 +311,26 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
public void testSearchWithAcceptOrds() throws IOException {
|
public void testSearchWithAcceptOrds() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||||
HnswGraphBuilder builder =
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
new HnswGraphBuilder(
|
vectorEncoding = randomVectorEncoding();
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
HnswGraphBuilder<?> builder =
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
// the first 10 docs must not be deleted to ensure the expected recall
|
// the first 10 docs must not be deleted to ensure the expected recall
|
||||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
new float[] {1, 0},
|
getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.randomAccess(),
|
vectors.randomAccess(),
|
||||||
VectorSimilarityFunction.DOT_PRODUCT,
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
int[] nodes = nn.nodes();
|
int[] nodes = nn.nodes();
|
||||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||||
int sum = 0;
|
int sum = 0;
|
||||||
for (int node : nodes) {
|
for (int node : nodes) {
|
||||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||||
@ -319,9 +344,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||||
HnswGraphBuilder builder =
|
vectorEncoding = randomVectorEncoding();
|
||||||
new HnswGraphBuilder(
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
HnswGraphBuilder<?> builder =
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
// Only mark a few vectors as accepted
|
// Only mark a few vectors as accepted
|
||||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
||||||
@ -333,10 +360,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
int numAccepted = acceptOrds.cardinality();
|
int numAccepted = acceptOrds.cardinality();
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
new float[] {1, 0},
|
getTargetVector(),
|
||||||
numAccepted,
|
numAccepted,
|
||||||
vectors.randomAccess(),
|
vectors.randomAccess(),
|
||||||
VectorSimilarityFunction.DOT_PRODUCT,
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
@ -347,12 +375,17 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private float[] getTargetVector() {
|
||||||
|
return new float[] {1, 0};
|
||||||
|
}
|
||||||
|
|
||||||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||||
int nDoc = 1000;
|
int nDoc = 1000;
|
||||||
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<?> builder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
vectors, VectorSimilarityFunction.EUCLIDEAN, 16, 100, random().nextInt());
|
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
|
|
||||||
// Skip over half of the documents that are closest to the query vector
|
// Skip over half of the documents that are closest to the query vector
|
||||||
@ -362,15 +395,16 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
new float[] {1, 0},
|
getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.randomAccess(),
|
vectors.randomAccess(),
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorEncoding.FLOAT32,
|
||||||
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
int[] nodes = nn.nodes();
|
int[] nodes = nn.nodes();
|
||||||
assertTrue("Number of found results is not equal to [10].", nodes.length == 10);
|
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||||
int sum = 0;
|
int sum = 0;
|
||||||
for (int node : nodes) {
|
for (int node : nodes) {
|
||||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||||
@ -383,20 +417,23 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
|
|
||||||
public void testVisitedLimit() throws IOException {
|
public void testVisitedLimit() throws IOException {
|
||||||
int nDoc = 500;
|
int nDoc = 500;
|
||||||
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<?> builder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 16, 100, random().nextInt());
|
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
|
|
||||||
int topK = 50;
|
int topK = 50;
|
||||||
int visitedLimit = topK + random().nextInt(5);
|
int visitedLimit = topK + random().nextInt(5);
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
new float[] {1, 0},
|
getTargetVector(),
|
||||||
topK,
|
topK,
|
||||||
vectors.randomAccess(),
|
vectors.randomAccess(),
|
||||||
VectorSimilarityFunction.DOT_PRODUCT,
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
createRandomAcceptOrds(0, vectors.size),
|
createRandomAcceptOrds(0, vectors.size),
|
||||||
visitedLimit);
|
visitedLimit);
|
||||||
@ -406,54 +443,68 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void testHnswGraphBuilderInvalid() {
|
public void testHnswGraphBuilderInvalid() {
|
||||||
expectThrows(NullPointerException.class, () -> new HnswGraphBuilder(null, null, 0, 0, 0));
|
expectThrows(
|
||||||
|
NullPointerException.class, () -> HnswGraphBuilder.create(null, null, null, 0, 0, 0));
|
||||||
|
// M must be > 0
|
||||||
expectThrows(
|
expectThrows(
|
||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() ->
|
() ->
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
new RandomVectorValues(1, 1, random()),
|
new RandomVectorValues(1, 1, random()),
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
0,
|
0,
|
||||||
10,
|
10,
|
||||||
0));
|
0));
|
||||||
|
// beamWidth must be > 0
|
||||||
expectThrows(
|
expectThrows(
|
||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() ->
|
() ->
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
new RandomVectorValues(1, 1, random()),
|
new RandomVectorValues(1, 1, random()),
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
10,
|
10,
|
||||||
0,
|
0,
|
||||||
0));
|
0));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testDiversity() throws IOException {
|
public void testDiversity() throws IOException {
|
||||||
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
// Some carefully checked test cases with simple 2d vectors on the unit circle:
|
// Some carefully checked test cases with simple 2d vectors on the unit circle:
|
||||||
MockVectorValues vectors =
|
float[][] values = {
|
||||||
new MockVectorValues(
|
unitVector2d(0.5),
|
||||||
new float[][] {
|
unitVector2d(0.75),
|
||||||
unitVector2d(0.5),
|
unitVector2d(0.2),
|
||||||
unitVector2d(0.75),
|
unitVector2d(0.9),
|
||||||
unitVector2d(0.2),
|
unitVector2d(0.8),
|
||||||
unitVector2d(0.9),
|
unitVector2d(0.77),
|
||||||
unitVector2d(0.8),
|
};
|
||||||
unitVector2d(0.77),
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
});
|
for (float[] v : values) {
|
||||||
|
for (int i = 0; i < v.length; i++) {
|
||||||
|
v[i] *= 127;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
MockVectorValues vectors = new MockVectorValues(values);
|
||||||
// First add nodes until everybody gets a full neighbor list
|
// First add nodes until everybody gets a full neighbor list
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<?> builder =
|
||||||
new HnswGraphBuilder(
|
HnswGraphBuilder.create(
|
||||||
vectors, VectorSimilarityFunction.DOT_PRODUCT, 2, 10, random().nextInt());
|
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
|
||||||
// node 0 is added by the builder constructor
|
// node 0 is added by the builder constructor
|
||||||
// builder.addGraphNode(vectors.vectorValue(0));
|
// builder.addGraphNode(vectors.vectorValue(0));
|
||||||
builder.addGraphNode(1, vectors.vectorValue(1));
|
builder.addGraphNode(1, vectors);
|
||||||
builder.addGraphNode(2, vectors.vectorValue(2));
|
builder.addGraphNode(2, vectors);
|
||||||
// now every node has tried to attach every other node as a neighbor, but
|
// now every node has tried to attach every other node as a neighbor, but
|
||||||
// some were excluded based on diversity check.
|
// some were excluded based on diversity check.
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
||||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||||
|
|
||||||
builder.addGraphNode(3, vectors.vectorValue(3));
|
builder.addGraphNode(3, vectors);
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
// we added 3 here
|
// we added 3 here
|
||||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
|
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
|
||||||
@ -461,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
assertLevel0Neighbors(builder.hnsw, 3, 1);
|
assertLevel0Neighbors(builder.hnsw, 3, 1);
|
||||||
|
|
||||||
// supplant an existing neighbor
|
// supplant an existing neighbor
|
||||||
builder.addGraphNode(4, vectors.vectorValue(4));
|
builder.addGraphNode(4, vectors);
|
||||||
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
|
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
|
||||||
@ -470,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
|
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
|
||||||
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
|
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
|
||||||
|
|
||||||
builder.addGraphNode(5, vectors.vectorValue(5));
|
builder.addGraphNode(5, vectors);
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
|
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
|
||||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||||
@ -494,29 +545,46 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
public void testRandom() throws IOException {
|
public void testRandom() throws IOException {
|
||||||
int size = atLeast(100);
|
int size = atLeast(100);
|
||||||
int dim = atLeast(10);
|
int dim = atLeast(10);
|
||||||
RandomVectorValues vectors = new RandomVectorValues(size, dim, random());
|
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
|
||||||
VectorSimilarityFunction similarityFunction =
|
|
||||||
VectorSimilarityFunction.values()[
|
|
||||||
random().nextInt(VectorSimilarityFunction.values().length - 1) + 1];
|
|
||||||
int topK = 5;
|
int topK = 5;
|
||||||
HnswGraphBuilder builder =
|
HnswGraphBuilder<?> builder =
|
||||||
new HnswGraphBuilder(vectors, similarityFunction, 10, 30, random().nextLong());
|
HnswGraphBuilder.create(
|
||||||
|
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||||
|
|
||||||
int totalMatches = 0;
|
int totalMatches = 0;
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
float[] query = randomVector(random(), dim);
|
NeighborQueue actual;
|
||||||
NeighborQueue actual =
|
float[] query;
|
||||||
|
BytesRef bQuery = null;
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
query = randomVector8(random(), dim);
|
||||||
|
bQuery = toBytesRef(query);
|
||||||
|
} else {
|
||||||
|
query = randomVector(random(), dim);
|
||||||
|
}
|
||||||
|
actual =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
query, 100, vectors, similarityFunction, hnsw, acceptOrds, Integer.MAX_VALUE);
|
query,
|
||||||
|
100,
|
||||||
|
vectors,
|
||||||
|
vectorEncoding,
|
||||||
|
similarityFunction,
|
||||||
|
hnsw,
|
||||||
|
acceptOrds,
|
||||||
|
Integer.MAX_VALUE);
|
||||||
while (actual.size() > topK) {
|
while (actual.size() > topK) {
|
||||||
actual.pop();
|
actual.pop();
|
||||||
}
|
}
|
||||||
NeighborQueue expected = new NeighborQueue(topK, false);
|
NeighborQueue expected = new NeighborQueue(topK, false);
|
||||||
for (int j = 0; j < size; j++) {
|
for (int j = 0; j < size; j++) {
|
||||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||||
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
|
||||||
|
} else {
|
||||||
|
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
|
||||||
|
}
|
||||||
if (expected.size() > topK) {
|
if (expected.size() > topK) {
|
||||||
expected.pop();
|
expected.pop();
|
||||||
}
|
}
|
||||||
@ -553,12 +621,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||||
private final int size;
|
private final int size;
|
||||||
private final float[] value;
|
private final float[] value;
|
||||||
|
private final BytesRef binaryValue;
|
||||||
|
|
||||||
int doc = -1;
|
int doc = -1;
|
||||||
|
|
||||||
CircularVectorValues(int size) {
|
CircularVectorValues(int size) {
|
||||||
this.size = size;
|
this.size = size;
|
||||||
value = new float[2];
|
value = new float[2];
|
||||||
|
binaryValue = new BytesRef(new byte[2]);
|
||||||
}
|
}
|
||||||
|
|
||||||
public CircularVectorValues copy() {
|
public CircularVectorValues copy() {
|
||||||
@ -617,7 +687,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int ord) {
|
public BytesRef binaryValue(int ord) {
|
||||||
return null;
|
float[] vectorValue = vectorValue(ord);
|
||||||
|
for (int i = 0; i < vectorValue.length; i++) {
|
||||||
|
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
|
||||||
|
}
|
||||||
|
return binaryValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -648,8 +722,9 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
if (uDoc == NO_MORE_DOCS) {
|
if (uDoc == NO_MORE_DOCS) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
|
||||||
assertArrayEquals(
|
assertArrayEquals(
|
||||||
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), 1e-4f);
|
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -657,7 +732,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
static class RandomVectorValues extends MockVectorValues {
|
static class RandomVectorValues extends MockVectorValues {
|
||||||
|
|
||||||
RandomVectorValues(int size, int dimension, Random random) {
|
RandomVectorValues(int size, int dimension, Random random) {
|
||||||
super(createRandomVectors(size, dimension, random));
|
super(createRandomVectors(size, dimension, null, random));
|
||||||
|
}
|
||||||
|
|
||||||
|
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
||||||
|
super(createRandomVectors(size, dimension, vectorEncoding, random));
|
||||||
}
|
}
|
||||||
|
|
||||||
RandomVectorValues(RandomVectorValues other) {
|
RandomVectorValues(RandomVectorValues other) {
|
||||||
@ -669,11 +748,21 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
return new RandomVectorValues(this);
|
return new RandomVectorValues(this);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float[][] createRandomVectors(int size, int dimension, Random random) {
|
private static float[][] createRandomVectors(
|
||||||
|
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
||||||
float[][] vectors = new float[size][];
|
float[][] vectors = new float[size][];
|
||||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||||
vectors[offset] = randomVector(random, dimension);
|
vectors[offset] = randomVector(random, dimension);
|
||||||
}
|
}
|
||||||
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
|
for (float[] vector : vectors) {
|
||||||
|
if (vector != null) {
|
||||||
|
for (int i = 0; i < vector.length; i++) {
|
||||||
|
vector[i] = (byte) (127 * vector[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
return vectors;
|
return vectors;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -701,8 +790,19 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||||||
float[] vec = new float[dim];
|
float[] vec = new float[dim];
|
||||||
for (int i = 0; i < dim; i++) {
|
for (int i = 0; i < dim; i++) {
|
||||||
vec[i] = random.nextFloat();
|
vec[i] = random.nextFloat();
|
||||||
|
if (random.nextBoolean()) {
|
||||||
|
vec[i] = -vec[i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
VectorUtil.l2normalize(vec);
|
VectorUtil.l2normalize(vec);
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static float[] randomVector8(Random random, int dim) {
|
||||||
|
float[] fvec = randomVector(random, dim);
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
fvec[i] *= 127;
|
||||||
|
}
|
||||||
|
return fvec;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -34,8 +34,10 @@ import org.apache.lucene.index.SortedNumericDocValues;
|
|||||||
import org.apache.lucene.index.SortedSetDocValues;
|
import org.apache.lucene.index.SortedSetDocValues;
|
||||||
import org.apache.lucene.index.StoredFieldVisitor;
|
import org.apache.lucene.index.StoredFieldVisitor;
|
||||||
import org.apache.lucene.index.Terms;
|
import org.apache.lucene.index.Terms;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
import org.apache.lucene.util.Version;
|
import org.apache.lucene.util.Version;
|
||||||
@ -97,6 +99,7 @@ public class TermVectorLeafReader extends LeafReader {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false);
|
false);
|
||||||
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
|
fieldInfos = new FieldInfos(new FieldInfo[] {fieldInfo});
|
||||||
@ -166,6 +169,12 @@ public class TermVectorLeafReader extends LeafReader {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {}
|
public void checkIntegrity() throws IOException {}
|
||||||
|
|
||||||
|
@ -37,6 +37,7 @@ import org.apache.lucene.document.FieldType;
|
|||||||
import org.apache.lucene.index.*;
|
import org.apache.lucene.index.*;
|
||||||
import org.apache.lucene.search.Collector;
|
import org.apache.lucene.search.Collector;
|
||||||
import org.apache.lucene.search.CollectorManager;
|
import org.apache.lucene.search.CollectorManager;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.Query;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.Scorable;
|
import org.apache.lucene.search.Scorable;
|
||||||
@ -514,6 +515,7 @@ public class MemoryIndex {
|
|||||||
fieldType.pointIndexDimensionCount(),
|
fieldType.pointIndexDimensionCount(),
|
||||||
fieldType.pointNumBytes(),
|
fieldType.pointNumBytes(),
|
||||||
fieldType.vectorDimension(),
|
fieldType.vectorDimension(),
|
||||||
|
fieldType.vectorEncoding(),
|
||||||
fieldType.vectorSimilarityFunction(),
|
fieldType.vectorSimilarityFunction(),
|
||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
@ -546,6 +548,7 @@ public class MemoryIndex {
|
|||||||
info.fieldInfo.getPointIndexDimensionCount(),
|
info.fieldInfo.getPointIndexDimensionCount(),
|
||||||
info.fieldInfo.getPointNumBytes(),
|
info.fieldInfo.getPointNumBytes(),
|
||||||
info.fieldInfo.getVectorDimension(),
|
info.fieldInfo.getVectorDimension(),
|
||||||
|
info.fieldInfo.getVectorEncoding(),
|
||||||
info.fieldInfo.getVectorSimilarityFunction(),
|
info.fieldInfo.getVectorSimilarityFunction(),
|
||||||
info.fieldInfo.isSoftDeletesField());
|
info.fieldInfo.isSoftDeletesField());
|
||||||
} else if (existingDocValuesType != docValuesType) {
|
} else if (existingDocValuesType != docValuesType) {
|
||||||
@ -1371,6 +1374,12 @@ public class MemoryIndex {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
// no-op
|
// no-op
|
||||||
|
@ -29,6 +29,7 @@ import org.apache.lucene.index.SegmentReadState;
|
|||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
@ -61,7 +62,7 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
return delegate.addField(fieldInfo);
|
return delegate.addField(fieldInfo);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -131,6 +132,18 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||||||
return hits;
|
return hits;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
FieldInfo fi = fis.fieldInfo(field);
|
||||||
|
assert fi != null && fi.getVectorDimension() > 0;
|
||||||
|
assert acceptDocs != null;
|
||||||
|
TopDocs hits = delegate.searchExhaustively(field, target, k, acceptDocs);
|
||||||
|
assert hits != null;
|
||||||
|
assert hits.scoreDocs.length <= k;
|
||||||
|
return hits;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
delegate.close();
|
delegate.close();
|
||||||
|
@ -36,6 +36,7 @@ import org.apache.lucene.index.IndexOptions;
|
|||||||
import org.apache.lucene.index.IndexableFieldType;
|
import org.apache.lucene.index.IndexableFieldType;
|
||||||
import org.apache.lucene.index.PointValues;
|
import org.apache.lucene.index.PointValues;
|
||||||
import org.apache.lucene.index.SegmentInfo;
|
import org.apache.lucene.index.SegmentInfo;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.internal.tests.IndexPackageAccess;
|
import org.apache.lucene.internal.tests.IndexPackageAccess;
|
||||||
@ -305,6 +306,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||||||
fieldType.pointIndexDimensionCount(),
|
fieldType.pointIndexDimensionCount(),
|
||||||
fieldType.pointNumBytes(),
|
fieldType.pointNumBytes(),
|
||||||
fieldType.vectorDimension(),
|
fieldType.vectorDimension(),
|
||||||
|
fieldType.vectorEncoding(),
|
||||||
fieldType.vectorSimilarityFunction(),
|
fieldType.vectorSimilarityFunction(),
|
||||||
field.equals(softDeletesField));
|
field.equals(softDeletesField));
|
||||||
addAttributes(fi);
|
addAttributes(fi);
|
||||||
@ -353,7 +355,8 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||||||
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
|
int dimension = 1 + r.nextInt(VectorValues.MAX_DIMENSIONS);
|
||||||
VectorSimilarityFunction similarityFunction =
|
VectorSimilarityFunction similarityFunction =
|
||||||
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
|
RandomPicks.randomFrom(r, VectorSimilarityFunction.values());
|
||||||
type.setVectorDimensionsAndSimilarityFunction(dimension, similarityFunction);
|
VectorEncoding encoding = RandomPicks.randomFrom(r, VectorEncoding.values());
|
||||||
|
type.setVectorAttributes(dimension, encoding, similarityFunction);
|
||||||
}
|
}
|
||||||
|
|
||||||
return type;
|
return type;
|
||||||
@ -422,6 +425,7 @@ public abstract class BaseFieldInfoFormatTestCase extends BaseIndexFileFormatTes
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
@ -360,6 +360,7 @@ abstract class BaseIndexFileFormatTestCase extends LuceneTestCase {
|
|||||||
proto.getPointIndexDimensionCount(),
|
proto.getPointIndexDimensionCount(),
|
||||||
proto.getPointNumBytes(),
|
proto.getPointNumBytes(),
|
||||||
proto.getVectorDimension(),
|
proto.getVectorDimension(),
|
||||||
|
proto.getVectorEncoding(),
|
||||||
proto.getVectorSimilarityFunction(),
|
proto.getVectorSimilarityFunction(),
|
||||||
proto.isSoftDeletesField());
|
proto.isSoftDeletesField());
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ import org.apache.lucene.index.IndexWriterConfig;
|
|||||||
import org.apache.lucene.index.LeafReader;
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.Term;
|
import org.apache.lucene.index.Term;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
@ -51,8 +52,10 @@ import org.apache.lucene.store.Directory;
|
|||||||
import org.apache.lucene.store.FSDirectory;
|
import org.apache.lucene.store.FSDirectory;
|
||||||
import org.apache.lucene.tests.util.TestUtil;
|
import org.apache.lucene.tests.util.TestUtil;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base class aiming at testing {@link KnnVectorsFormat vectors formats}. To test a new format, all
|
* Base class aiming at testing {@link KnnVectorsFormat vectors formats}. To test a new format, all
|
||||||
@ -63,9 +66,21 @@ import org.apache.lucene.util.VectorUtil;
|
|||||||
*/
|
*/
|
||||||
public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
|
public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
|
||||||
|
|
||||||
|
private VectorEncoding vectorEncoding;
|
||||||
|
private VectorSimilarityFunction similarityFunction;
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void init() {
|
||||||
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
similarityFunction = randomSimilarity();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected void addRandomFields(Document doc) {
|
protected void addRandomFields(Document doc) {
|
||||||
doc.add(new KnnVectorField("v2", randomVector(30), VectorSimilarityFunction.EUCLIDEAN));
|
switch (vectorEncoding) {
|
||||||
|
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
|
||||||
|
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testFieldConstructor() {
|
public void testFieldConstructor() {
|
||||||
@ -133,8 +148,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
|
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
|
||||||
String errMsg =
|
String errMsg =
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=3, vector similarity function=DOT_PRODUCT";
|
+ "to inconsistent vector dimension=3, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT";
|
||||||
assertEquals(errMsg, expected.getMessage());
|
assertEquals(errMsg, expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -170,8 +185,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
|
expectThrows(IllegalArgumentException.class, () -> w.addDocument(doc2));
|
||||||
String errMsg =
|
String errMsg =
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN";
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN";
|
||||||
assertEquals(errMsg, expected.getMessage());
|
assertEquals(errMsg, expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -190,8 +205,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
|
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=1, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=1, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -211,8 +226,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
|
expectThrows(IllegalArgumentException.class, () -> w2.addDocument(doc2));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=EUCLIDEAN",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -311,8 +326,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
expectThrows(
|
expectThrows(
|
||||||
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
|
IllegalArgumentException.class, () -> w2.addIndexes(new Directory[] {dir}));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -333,8 +348,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
|
expectThrows(IllegalArgumentException.class, () -> w2.addIndexes(dir));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -358,8 +373,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
|
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -384,8 +399,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
|
() -> w2.addIndexes(new CodecReader[] {(CodecReader) getOnlyLeafReader(r)}));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -408,8 +423,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
|
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=5, vector similarity function=DOT_PRODUCT "
|
"cannot change field \"f\" from vector dimension=5, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -432,8 +447,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
IllegalArgumentException expected =
|
IllegalArgumentException expected =
|
||||||
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
|
expectThrows(IllegalArgumentException.class, () -> TestUtil.addIndexesSlowly(w2, r));
|
||||||
assertEquals(
|
assertEquals(
|
||||||
"cannot change field \"f\" from vector dimension=4, vector similarity function=EUCLIDEAN "
|
"cannot change field \"f\" from vector dimension=4, vector encoding=FLOAT32, vector similarity function=EUCLIDEAN "
|
||||||
+ "to inconsistent vector dimension=4, vector similarity function=DOT_PRODUCT",
|
+ "to inconsistent vector dimension=4, vector encoding=FLOAT32, vector similarity function=DOT_PRODUCT",
|
||||||
expected.getMessage());
|
expected.getMessage());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -596,12 +611,12 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
int[] fieldDocCounts = new int[numFields];
|
int[] fieldDocCounts = new int[numFields];
|
||||||
double[] fieldTotals = new double[numFields];
|
double[] fieldTotals = new double[numFields];
|
||||||
int[] fieldDims = new int[numFields];
|
int[] fieldDims = new int[numFields];
|
||||||
VectorSimilarityFunction[] fieldSearchStrategies = new VectorSimilarityFunction[numFields];
|
VectorSimilarityFunction[] fieldSimilarityFunctions = new VectorSimilarityFunction[numFields];
|
||||||
|
VectorEncoding[] fieldVectorEncodings = new VectorEncoding[numFields];
|
||||||
for (int i = 0; i < numFields; i++) {
|
for (int i = 0; i < numFields; i++) {
|
||||||
fieldDims[i] = random().nextInt(20) + 1;
|
fieldDims[i] = random().nextInt(20) + 1;
|
||||||
fieldSearchStrategies[i] =
|
fieldSimilarityFunctions[i] = randomSimilarity();
|
||||||
VectorSimilarityFunction.values()[
|
fieldVectorEncodings[i] = randomVectorEncoding();
|
||||||
random().nextInt(VectorSimilarityFunction.values().length)];
|
|
||||||
}
|
}
|
||||||
try (Directory dir = newDirectory();
|
try (Directory dir = newDirectory();
|
||||||
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
|
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
|
||||||
@ -610,15 +625,23 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
for (int field = 0; field < numFields; field++) {
|
for (int field = 0; field < numFields; field++) {
|
||||||
String fieldName = "int" + field;
|
String fieldName = "int" + field;
|
||||||
if (random().nextInt(100) == 17) {
|
if (random().nextInt(100) == 17) {
|
||||||
float[] v = randomVector(fieldDims[field]);
|
switch (fieldVectorEncodings[field]) {
|
||||||
doc.add(new KnnVectorField(fieldName, v, fieldSearchStrategies[field]));
|
case BYTE -> {
|
||||||
|
BytesRef b = randomVector8(fieldDims[field]);
|
||||||
|
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
|
||||||
|
fieldTotals[field] += b.bytes[b.offset];
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
float[] v = randomVector(fieldDims[field]);
|
||||||
|
doc.add(new KnnVectorField(fieldName, v, fieldSimilarityFunctions[field]));
|
||||||
|
fieldTotals[field] += v[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
fieldDocCounts[field]++;
|
fieldDocCounts[field]++;
|
||||||
fieldTotals[field] += v[0];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
w.addDocument(doc);
|
w.addDocument(doc);
|
||||||
}
|
}
|
||||||
|
|
||||||
try (IndexReader r = w.getReader()) {
|
try (IndexReader r = w.getReader()) {
|
||||||
for (int field = 0; field < numFields; field++) {
|
for (int field = 0; field < numFields; field++) {
|
||||||
int docCount = 0;
|
int docCount = 0;
|
||||||
@ -634,12 +657,29 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertEquals(fieldDocCounts[field], docCount);
|
assertEquals(fieldDocCounts[field], docCount);
|
||||||
assertEquals(fieldTotals[field], checksum, 1e-5);
|
// Account for quantization done when indexing fields w/BYTE encoding
|
||||||
|
double delta = fieldVectorEncodings[field] == VectorEncoding.BYTE ? numDocs * 0.01 : 1e-5;
|
||||||
|
assertEquals(fieldTotals[field], checksum, delta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private VectorSimilarityFunction randomSimilarity() {
|
||||||
|
return VectorSimilarityFunction.values()[
|
||||||
|
random().nextInt(VectorSimilarityFunction.values().length)];
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorEncoding randomVectorEncoding() {
|
||||||
|
Codec codec = getCodec();
|
||||||
|
if (codec.knnVectorsFormat().currentVersion()
|
||||||
|
>= Codec.forName("Lucene94").knnVectorsFormat().currentVersion()) {
|
||||||
|
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||||
|
} else {
|
||||||
|
return VectorEncoding.FLOAT32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public void testIndexedValueNotAliased() throws Exception {
|
public void testIndexedValueNotAliased() throws Exception {
|
||||||
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
|
// We copy indexed values (as for BinaryDocValues) so the input float[] can be reused across
|
||||||
// calls to IndexWriter.addDocument.
|
// calls to IndexWriter.addDocument.
|
||||||
@ -742,7 +782,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
assertEquals(3, vectorValues3.dimension());
|
assertEquals(3, vectorValues3.dimension());
|
||||||
assertEquals(1, vectorValues3.size());
|
assertEquals(1, vectorValues3.size());
|
||||||
vectorValues3.nextDoc();
|
vectorValues3.nextDoc();
|
||||||
assertEquals(1f, vectorValues3.vectorValue()[0], 0);
|
assertEquals(1f, vectorValues3.vectorValue()[0], 0.1);
|
||||||
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
|
assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -775,9 +815,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
if (random().nextBoolean() && values[i] != null) {
|
if (random().nextBoolean() && values[i] != null) {
|
||||||
// sometimes use a shared scratch array
|
// sometimes use a shared scratch array
|
||||||
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
|
System.arraycopy(values[i], 0, scratch, 0, scratch.length);
|
||||||
add(iw, fieldName, i, scratch, VectorSimilarityFunction.EUCLIDEAN);
|
add(iw, fieldName, i, scratch, similarityFunction);
|
||||||
} else {
|
} else {
|
||||||
add(iw, fieldName, i, values[i], VectorSimilarityFunction.EUCLIDEAN);
|
add(iw, fieldName, i, values[i], similarityFunction);
|
||||||
}
|
}
|
||||||
if (random().nextInt(10) == 2) {
|
if (random().nextInt(10) == 2) {
|
||||||
// sometimes delete a random document
|
// sometimes delete a random document
|
||||||
@ -898,7 +938,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
int numDoc = atLeast(100);
|
int numDoc = atLeast(100);
|
||||||
int dimension = atLeast(10);
|
int dimension = atLeast(10);
|
||||||
float[][] id2value = new float[numDoc][];
|
float[][] id2value = new float[numDoc][];
|
||||||
int[] id2ord = new int[numDoc];
|
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
int id = random().nextInt(numDoc);
|
int id = random().nextInt(numDoc);
|
||||||
float[] value;
|
float[] value;
|
||||||
@ -909,7 +948,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
value = null;
|
value = null;
|
||||||
}
|
}
|
||||||
id2value[id] = value;
|
id2value[id] = value;
|
||||||
id2ord[id] = i;
|
|
||||||
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
|
add(iw, fieldName, id, value, VectorSimilarityFunction.EUCLIDEAN);
|
||||||
}
|
}
|
||||||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||||
@ -1007,6 +1045,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private BytesRef randomVector8(int dim) {
|
||||||
|
float[] v = randomVector(dim);
|
||||||
|
byte[] b = new byte[dim];
|
||||||
|
for (int i = 0; i < dim; i++) {
|
||||||
|
b[i] = (byte) (v[i] * 127);
|
||||||
|
}
|
||||||
|
return new BytesRef(b);
|
||||||
|
}
|
||||||
|
|
||||||
public void testCheckIndexIncludesVectors() throws Exception {
|
public void testCheckIndexIncludesVectors() throws Exception {
|
||||||
try (Directory dir = newDirectory()) {
|
try (Directory dir = newDirectory()) {
|
||||||
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||||
@ -1041,6 +1088,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
assertEquals(3, VectorSimilarityFunction.values().length);
|
assertEquals(3, VectorSimilarityFunction.values().length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testVectorEncodingOrdinals() {
|
||||||
|
// make sure we don't accidentally mess up vector encoding identifiers by re-ordering their
|
||||||
|
// enumerators
|
||||||
|
assertEquals(0, VectorEncoding.BYTE.ordinal());
|
||||||
|
assertEquals(1, VectorEncoding.FLOAT32.ordinal());
|
||||||
|
assertEquals(2, VectorEncoding.values().length);
|
||||||
|
}
|
||||||
|
|
||||||
public void testAdvance() throws Exception {
|
public void testAdvance() throws Exception {
|
||||||
try (Directory dir = newDirectory()) {
|
try (Directory dir = newDirectory()) {
|
||||||
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||||
@ -1091,10 +1146,6 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
public void testVectorValuesReportCorrectDocs() throws Exception {
|
public void testVectorValuesReportCorrectDocs() throws Exception {
|
||||||
final int numDocs = atLeast(1000);
|
final int numDocs = atLeast(1000);
|
||||||
final int dim = random().nextInt(20) + 1;
|
final int dim = random().nextInt(20) + 1;
|
||||||
final VectorSimilarityFunction similarityFunction =
|
|
||||||
VectorSimilarityFunction.values()[
|
|
||||||
random().nextInt(VectorSimilarityFunction.values().length)];
|
|
||||||
|
|
||||||
double fieldValuesCheckSum = 0;
|
double fieldValuesCheckSum = 0;
|
||||||
int fieldDocCount = 0;
|
int fieldDocCount = 0;
|
||||||
long fieldSumDocIDs = 0;
|
long fieldSumDocIDs = 0;
|
||||||
@ -1106,9 +1157,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
int docID = random().nextInt(numDocs);
|
int docID = random().nextInt(numDocs);
|
||||||
doc.add(new StoredField("id", docID));
|
doc.add(new StoredField("id", docID));
|
||||||
if (random().nextInt(4) == 3) {
|
if (random().nextInt(4) == 3) {
|
||||||
float[] vector = randomVector(dim);
|
switch (vectorEncoding) {
|
||||||
doc.add(new KnnVectorField("knn_vector", vector, similarityFunction));
|
case BYTE -> {
|
||||||
fieldValuesCheckSum += vector[0];
|
BytesRef b = randomVector8(dim);
|
||||||
|
fieldValuesCheckSum += b.bytes[b.offset];
|
||||||
|
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
float[] v = randomVector(dim);
|
||||||
|
fieldValuesCheckSum += v[0];
|
||||||
|
doc.add(new KnnVectorField("knn_vector", v, similarityFunction));
|
||||||
|
}
|
||||||
|
}
|
||||||
fieldDocCount++;
|
fieldDocCount++;
|
||||||
fieldSumDocIDs += docID;
|
fieldSumDocIDs += docID;
|
||||||
}
|
}
|
||||||
@ -1134,7 +1194,10 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
assertEquals(fieldValuesCheckSum, checksum, 1e-3);
|
assertEquals(
|
||||||
|
fieldValuesCheckSum,
|
||||||
|
checksum,
|
||||||
|
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
|
||||||
assertEquals(fieldDocCount, docCount);
|
assertEquals(fieldDocCount, docCount);
|
||||||
assertEquals(fieldSumDocIDs, sumDocIds);
|
assertEquals(fieldSumDocIDs, sumDocIds);
|
||||||
}
|
}
|
||||||
|
@ -40,6 +40,7 @@ import org.apache.lucene.index.SortedSetDocValues;
|
|||||||
import org.apache.lucene.index.StoredFieldVisitor;
|
import org.apache.lucene.index.StoredFieldVisitor;
|
||||||
import org.apache.lucene.index.Terms;
|
import org.apache.lucene.index.Terms;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.TopDocs;
|
import org.apache.lucene.search.TopDocs;
|
||||||
import org.apache.lucene.util.Bits;
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
@ -228,6 +229,12 @@ class MergeReaderWrapper extends LeafReader {
|
|||||||
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
return in.searchNearestVectors(field, target, k, acceptDocs, visitedLimit);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) throws IOException {
|
||||||
|
return in.searchNearestVectorsExhaustively(field, target, k, acceptDocs);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int numDocs() {
|
public int numDocs() {
|
||||||
return in.numDocs();
|
return in.numDocs();
|
||||||
|
@ -88,6 +88,7 @@ public class MismatchedLeafReader extends FilterLeafReader {
|
|||||||
oldInfo.getPointIndexDimensionCount(), // index dimension count
|
oldInfo.getPointIndexDimensionCount(), // index dimension count
|
||||||
oldInfo.getPointNumBytes(), // dimension numBytes
|
oldInfo.getPointNumBytes(), // dimension numBytes
|
||||||
oldInfo.getVectorDimension(), // number of dimensions of the field's vector
|
oldInfo.getVectorDimension(), // number of dimensions of the field's vector
|
||||||
|
oldInfo.getVectorEncoding(), // numeric type of vector samples
|
||||||
// distance function for calculating similarity of the field's vector
|
// distance function for calculating similarity of the field's vector
|
||||||
oldInfo.getVectorSimilarityFunction(),
|
oldInfo.getVectorSimilarityFunction(),
|
||||||
oldInfo.isSoftDeletesField()); // used as soft-deletes field
|
oldInfo.isSoftDeletesField()); // used as soft-deletes field
|
||||||
|
@ -62,6 +62,7 @@ import org.apache.lucene.index.SegmentWriteState;
|
|||||||
import org.apache.lucene.index.TermState;
|
import org.apache.lucene.index.TermState;
|
||||||
import org.apache.lucene.index.Terms;
|
import org.apache.lucene.index.Terms;
|
||||||
import org.apache.lucene.index.TermsEnum;
|
import org.apache.lucene.index.TermsEnum;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.internal.tests.IndexPackageAccess;
|
import org.apache.lucene.internal.tests.IndexPackageAccess;
|
||||||
import org.apache.lucene.internal.tests.TestSecrets;
|
import org.apache.lucene.internal.tests.TestSecrets;
|
||||||
@ -163,6 +164,7 @@ public class RandomPostingsTester {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false);
|
false);
|
||||||
fieldUpto++;
|
fieldUpto++;
|
||||||
@ -734,6 +736,7 @@ public class RandomPostingsTester {
|
|||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
0,
|
0,
|
||||||
|
VectorEncoding.FLOAT32,
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
false);
|
false);
|
||||||
}
|
}
|
||||||
|
@ -233,6 +233,12 @@ public class QueryUtils {
|
|||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public TopDocs searchNearestVectorsExhaustively(
|
||||||
|
String field, float[] target, int k, DocIdSetIterator acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FieldInfos getFieldInfos() {
|
public FieldInfos getFieldInfos() {
|
||||||
return FieldInfos.EMPTY;
|
return FieldInfos.EMPTY;
|
||||||
|
Loading…
x
Reference in New Issue
Block a user