mirror of https://github.com/apache/lucene.git
Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064)
This commit is contained in:
parent
e14327288e
commit
cc29102a24
|
@ -48,7 +48,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
private final Lucene90NeighborArray scratch;
|
||||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final RandomAccessVectorValues<float[]> vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final Lucene90BoundsChecker bound;
|
||||
final Lucene90OnHeapHnswGraph hnsw;
|
||||
|
@ -57,7 +57,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private final RandomAccessVectorValues buildVectors;
|
||||
private final RandomAccessVectorValues<float[]> buildVectors;
|
||||
|
||||
/**
|
||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||
|
@ -72,7 +72,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene90HnswGraphBuilder(
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<float[]> vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
|
@ -103,7 +103,8 @@ public final class Lucene90HnswGraphBuilder {
|
|||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
|
||||
throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -229,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
|
|||
float[] candidate,
|
||||
float score,
|
||||
Lucene90NeighborArray neighbors,
|
||||
RandomAccessVectorValues vectorValues)
|
||||
RandomAccessVectorValues<float[]> vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Map;
|
|||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -232,6 +233,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
return getOffHeapVectorValues(fieldEntry);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
@ -352,7 +358,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
static class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
final int dimension;
|
||||
final int[] ordToDoc;
|
||||
|
@ -433,7 +440,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() {
|
||||
public RandomAccessVectorValues<float[]> copy() {
|
||||
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
|
||||
}
|
||||
|
||||
|
@ -443,17 +450,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
dataIn.readFloats(value, 0, value.length);
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
dataIn.seek((long) targetOrd * byteSize);
|
||||
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
|
|
|
@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
|||
float[] query,
|
||||
int topK,
|
||||
int numSeed,
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<float[]> vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
HnswGraph graphValues,
|
||||
Bits acceptOrds,
|
||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Map;
|
|||
import java.util.function.IntUnaryOperator;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -224,6 +225,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
return getOffHeapVectorValues(fieldEntry);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
@ -398,7 +404,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
static class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
private final int dimension;
|
||||
private final int size;
|
||||
|
@ -486,7 +493,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() {
|
||||
public RandomAccessVectorValues<float[]> copy() {
|
||||
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
|
||||
}
|
||||
|
||||
|
@ -496,17 +503,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
dataIn.readFloats(value, 0, value.length);
|
||||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
dataIn.seek((long) targetOrd * byteSize);
|
||||
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -219,6 +220,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
|||
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
|
|
@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
|||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
abstract class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
public abstract int ordToDoc(int ord);
|
||||
|
||||
static OffHeapVectorValues load(
|
||||
|
@ -137,7 +127,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||
}
|
||||
|
||||
|
@ -210,7 +200,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
||||
}
|
||||
|
||||
|
@ -282,7 +272,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
@ -291,11 +281,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -233,12 +234,31 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
|
||||
return new ExpandingVectorValues(values);
|
||||
} else {
|
||||
return values;
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -292,7 +312,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
// bound k by total number of vectors to prevent oversizing data structures
|
||||
k = Math.min(k, fieldEntry.size());
|
||||
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||
|
||||
NeighborQueue results =
|
||||
HnswGraphSearcher.search(
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.backward_codecs.lucene94;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues<BytesRef> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
protected final IndexInput slice;
|
||||
protected final BytesRef binaryValue;
|
||||
protected final ByteBuffer byteBuffer;
|
||||
protected final int byteSize;
|
||||
|
||||
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
public abstract int ordToDoc(int ord);
|
||||
|
||||
static OffHeapByteVectorValues load(
|
||||
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||
if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
int byteSize = fieldEntry.dimension;
|
||||
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
|
||||
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
slice.seek((long) doc * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
private final DirectMonotonicReader ordToDoc;
|
||||
private final IndexedDISI disi;
|
||||
// dataIn was used to init a new IndexedDIS for #randomAccess()
|
||||
private final IndexInput dataIn;
|
||||
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
|
||||
|
||||
public SparseOffHeapVectorValues(
|
||||
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
int byteSize)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||
this.dataIn = dataIn;
|
||||
this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
|
||||
this.disi =
|
||||
new IndexedDISI(
|
||||
dataIn,
|
||||
fieldEntry.docsWithFieldOffset,
|
||||
fieldEntry.docsWithFieldLength,
|
||||
fieldEntry.jumpTableEntryCount,
|
||||
fieldEntry.denseRankPower,
|
||||
fieldEntry.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
slice.seek((long) (disi.index()) * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return (int) ordToDoc.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
return new Bits() {
|
||||
@Override
|
||||
public boolean get(int index) {
|
||||
return acceptDocs.get(ordToDoc(index));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return size;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
|||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
abstract class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
public abstract int ordToDoc(int ord);
|
||||
|
||||
static OffHeapVectorValues load(
|
||||
|
@ -143,7 +133,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -219,7 +209,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -291,7 +281,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
@ -300,11 +290,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
|
||||
private void writeGraph(
|
||||
IndexOutput graphData,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
RandomAccessVectorValues<float[]> vectorValues,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
long graphDataOffset,
|
||||
long[] offsets,
|
||||
|
|
|
@ -53,7 +53,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
private final Lucene91NeighborArray scratch;
|
||||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final RandomAccessVectorValues<float[]> vectorValues;
|
||||
private final SplittableRandom random;
|
||||
private final Lucene91BoundsChecker bound;
|
||||
private final HnswGraphSearcher<float[]> graphSearcher;
|
||||
|
@ -64,7 +64,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private RandomAccessVectorValues buildVectors;
|
||||
private RandomAccessVectorValues<float[]> buildVectors;
|
||||
|
||||
/**
|
||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||
|
@ -79,7 +79,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene91HnswGraphBuilder(
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<float[]> vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
|
@ -119,7 +119,8 @@ public final class Lucene91HnswGraphBuilder {
|
|||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||
* accessor for the vectors
|
||||
*/
|
||||
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
|
||||
throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -250,7 +251,7 @@ public final class Lucene91HnswGraphBuilder {
|
|||
float[] candidate,
|
||||
float score,
|
||||
Lucene91NeighborArray neighbors,
|
||||
RandomAccessVectorValues vectorValues)
|
||||
RandomAccessVectorValues<float[]> vectorValues)
|
||||
throws IOException {
|
||||
bound.set(score);
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
|
|
|
@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private Lucene91OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
|
|
|
@ -268,13 +268,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValues vectorValues,
|
||||
RandomAccessVectorValues<float[]> vectorValues,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
vectorEncoding,
|
||||
|
|
|
@ -30,12 +30,14 @@ import org.apache.lucene.codecs.CodecUtil;
|
|||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
@ -379,8 +381,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
vectorData.getName(), "temp", segmentWriteState.context);
|
||||
|
@ -389,7 +389,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
try {
|
||||
// write the vector data to a temporary file
|
||||
DocsWithFieldSet docsWithField =
|
||||
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
|
@ -405,23 +410,49 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||
// doesn't need to know docIds
|
||||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
OffHeapVectorValues offHeapVectors =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
||||
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
OnHeapHnswGraph graph = null;
|
||||
if (offHeapVectors.size() != 0) {
|
||||
if (docsWithField.cardinality() != 0) {
|
||||
// build graph
|
||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
offHeapVectors,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(offHeapVectors.copy());
|
||||
graph =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||
}
|
||||
};
|
||||
writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
@ -554,16 +585,37 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||
* vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = byteVectorValues.binaryValue();
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(
|
||||
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
||||
IndexOutput output, VectorValues floatVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = vectors.binaryValue();
|
||||
assert binaryValue.length == vectors.dimension() * scalarSize;
|
||||
BytesRef binaryValue = floatVectorValues.binaryValue();
|
||||
assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
|
@ -580,7 +632,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<T> vectors;
|
||||
private final RAVectorValues<T> raVectorValues;
|
||||
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||
|
||||
private int lastDocID = -1;
|
||||
|
@ -593,8 +644,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
||||
@Override
|
||||
public BytesRef copyValue(BytesRef value) {
|
||||
return new BytesRef(
|
||||
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
|
||||
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
|
||||
}
|
||||
};
|
||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
||||
|
@ -613,16 +663,15 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||
RAVectorValues<T> raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||
hnswGraphBuilder =
|
||||
(HnswGraphBuilder<T>)
|
||||
HnswGraphBuilder.create(
|
||||
raVectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
HnswGraphBuilder.create(
|
||||
raVectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(infoStream);
|
||||
}
|
||||
|
||||
|
@ -667,7 +716,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||
private final List<T> vectors;
|
||||
private final int dim;
|
||||
|
||||
|
@ -687,17 +736,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return (float[]) vectors.get(targetOrd);
|
||||
public T vectorValue(int targetOrd) throws IOException {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
return (BytesRef) vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RAVectorValues<T> copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets;
|
|||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
|
@ -143,6 +144,39 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||
return null;
|
||||
}
|
||||
int dimension = info.getVectorDimension();
|
||||
if (dimension == 0) {
|
||||
throw new IllegalStateException(
|
||||
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||
}
|
||||
FieldEntry fieldEntry = fieldEntries.get(field);
|
||||
if (fieldEntry == null) {
|
||||
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||
return null;
|
||||
}
|
||||
if (dimension != fieldEntry.dimension) {
|
||||
throw new IllegalStateException(
|
||||
"Inconsistent vector dimension for field=\""
|
||||
+ field
|
||||
+ "\"; "
|
||||
+ dimension
|
||||
+ " != "
|
||||
+ fieldEntry.dimension);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
@ -187,7 +221,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
VectorValues values = getVectorValues(field);
|
||||
ByteVectorValues values = getByteVectorValues(field);
|
||||
if (target.length != values.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"vector query dimension: "
|
||||
|
@ -213,7 +247,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
break;
|
||||
}
|
||||
|
||||
BytesRef vector = values.binaryValue();
|
||||
BytesRef vector = values.vectorValue();
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||
numVisited++;
|
||||
|
@ -301,7 +335,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
private static class SimpleTextVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues {
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||
private final FieldEntry entry;
|
||||
|
@ -356,7 +390,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() {
|
||||
public RandomAccessVectorValues<float[]> copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
@ -409,10 +443,99 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return values[targetOrd];
|
||||
}
|
||||
}
|
||||
|
||||
private static class SimpleTextByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues<BytesRef> {
|
||||
|
||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||
private final FieldEntry entry;
|
||||
private final IndexInput in;
|
||||
private final BytesRef binaryValue;
|
||||
private final byte[][] values;
|
||||
|
||||
int curOrd;
|
||||
|
||||
SimpleTextByteVectorValues(FieldEntry entry, IndexInput in) throws IOException {
|
||||
this.entry = entry;
|
||||
this.in = in;
|
||||
values = new byte[entry.size()][entry.dimension];
|
||||
binaryValue = new BytesRef(entry.dimension);
|
||||
binaryValue.length = binaryValue.bytes.length;
|
||||
curOrd = -1;
|
||||
readAllVectors();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
public int dimension() {
|
||||
return entry.dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return entry.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() {
|
||||
binaryValue.bytes = values[curOrd];
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curOrd == -1) {
|
||||
return -1;
|
||||
} else if (curOrd >= entry.size()) {
|
||||
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
|
||||
// immediately afterward should also return NO_MORE_DOCS
|
||||
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
return entry.ordToDoc[curOrd];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
if (++curOrd < entry.size()) {
|
||||
return docID();
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return slowAdvance(target);
|
||||
}
|
||||
|
||||
private void readAllVectors() throws IOException {
|
||||
for (byte[] value : values) {
|
||||
readVector(value);
|
||||
}
|
||||
}
|
||||
|
||||
private void readVector(byte[] value) throws IOException {
|
||||
SimpleTextUtil.readLine(in, scratch);
|
||||
// skip leading "[" and strip trailing "]"
|
||||
String s = new BytesRef(scratch.bytes(), 1, scratch.length() - 2).utf8ToString();
|
||||
String[] floatStrings = s.split(",");
|
||||
assert floatStrings.length == value.length
|
||||
: " read " + s + " when expecting " + value.length + " floats";
|
||||
for (int i = 0; i < floatStrings.length; i++) {
|
||||
value[i] = (byte) Float.parseFloat(floatStrings[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||
binaryValue.bytes = values[curOrd];
|
||||
return binaryValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
|
|||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
|
@ -85,6 +86,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
: vectorValues;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
@ -202,6 +208,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() {}
|
||||
};
|
||||
|
@ -228,7 +239,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, Object value) {
|
||||
public void addValue(int docID, float[] value) {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
|
@ -236,25 +247,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
float[] vectorValue =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> (float[]) value;
|
||||
case BYTE -> bytesToFloats((BytesRef) value);
|
||||
};
|
||||
docsWithField.add(docID);
|
||||
vectors.add(copyValue(vectorValue));
|
||||
vectors.add(copyValue(value));
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
private float[] bytesToFloats(BytesRef b) {
|
||||
// This is used only by SimpleTextKnnVectorsWriter
|
||||
float[] floats = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
floats[i] = b.bytes[i + b.offset];
|
||||
}
|
||||
return floats;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] copyValue(float[] vectorValue) {
|
||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
||||
|
|
|
@ -34,7 +34,7 @@ public abstract class KnnFieldVectorsWriter<T> implements Accountable {
|
|||
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
|
||||
* increasing order.
|
||||
*/
|
||||
public abstract void addValue(int docID, Object vectorValue) throws IOException;
|
||||
public abstract void addValue(int docID, T vectorValue) throws IOException;
|
||||
|
||||
/**
|
||||
* Used to copy values being indexed to internal storage.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.lucene.codecs;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -98,6 +99,11 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.codecs;
|
|||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
@ -51,6 +52,13 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
|||
*/
|
||||
public abstract VectorValues getVectorValues(String field) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
|
||||
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||
* never {@code null}.
|
||||
*/
|
||||
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||
|
||||
/**
|
||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||
|
|
|
@ -21,10 +21,12 @@ import java.io.Closeable;
|
|||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocIDMerger;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
|
@ -44,13 +46,29 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
|
||||
/** Write field for merging */
|
||||
@SuppressWarnings("unchecked")
|
||||
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
|
||||
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedValues.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedValues.nextDoc()) {
|
||||
writer.addValue(doc, mergedValues.vectorValue());
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
KnnFieldVectorsWriter<BytesRef> byteWriter =
|
||||
(KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
|
||||
ByteVectorValues mergedBytes =
|
||||
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedBytes.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedBytes.nextDoc()) {
|
||||
byteWriter.addValue(doc, mergedBytes.vectorValue());
|
||||
}
|
||||
break;
|
||||
case FLOAT32:
|
||||
KnnFieldVectorsWriter<float[]> floatWriter =
|
||||
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
|
||||
VectorValues mergedFloats = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
for (int doc = mergedFloats.nextDoc();
|
||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||
doc = mergedFloats.nextDoc()) {
|
||||
floatWriter.addValue(doc, mergedFloats.vectorValue());
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -104,20 +122,34 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
}
|
||||
}
|
||||
|
||||
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
|
||||
protected static class MergedVectorValues extends VectorValues {
|
||||
private final List<VectorValuesSub> subs;
|
||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
private static class ByteVectorValuesSub extends DocIDMerger.Sub {
|
||||
|
||||
private int docId;
|
||||
private VectorValuesSub current;
|
||||
final ByteVectorValues values;
|
||||
|
||||
ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return values.nextDoc();
|
||||
}
|
||||
}
|
||||
|
||||
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
|
||||
protected static final class MergedVectorValues {
|
||||
private MergedVectorValues() {}
|
||||
|
||||
/** Returns a merged view over all the segment's {@link VectorValues}. */
|
||||
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
public static VectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
throws IOException {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
|
||||
}
|
||||
List<VectorValuesSub> subs = new ArrayList<>();
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
|
@ -128,60 +160,147 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
}
|
||||
}
|
||||
}
|
||||
return new MergedVectorValues(subs, mergeState);
|
||||
return new MergedFloat32VectorValues(subs, mergeState);
|
||||
}
|
||||
|
||||
private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState)
|
||||
/** Returns a merged view over all the segment's {@link ByteVectorValues}. */
|
||||
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||
throws IOException {
|
||||
this.subs = subs;
|
||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||
int totalSize = 0;
|
||||
for (VectorValuesSub sub : subs) {
|
||||
totalSize += sub.values.size();
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||
throw new UnsupportedOperationException(
|
||||
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
List<ByteVectorValuesSub> subs = new ArrayList<>();
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||
if (knnVectorsReader != null) {
|
||||
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||
if (values != null) {
|
||||
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
|
||||
}
|
||||
}
|
||||
}
|
||||
return docId;
|
||||
return new MergedByteVectorValues(subs, mergeState);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
static class MergedFloat32VectorValues extends VectorValues {
|
||||
private final List<VectorValuesSub> subs;
|
||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
private int docId;
|
||||
VectorValuesSub current;
|
||||
|
||||
private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
|
||||
throws IOException {
|
||||
this.subs = subs;
|
||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||
int totalSize = 0;
|
||||
for (VectorValuesSub sub : subs) {
|
||||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
return current.values.binaryValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
return current.values.binaryValue();
|
||||
}
|
||||
static class MergedByteVectorValues extends ByteVectorValues {
|
||||
private final List<ByteVectorValuesSub> subs;
|
||||
private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
private int docId;
|
||||
ByteVectorValuesSub current;
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
|
||||
throws IOException {
|
||||
this.subs = subs;
|
||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||
int totalSize = 0;
|
||||
for (ByteVectorValuesSub sub : subs) {
|
||||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,47 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FilterVectorValues;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/** reads from byte-encoded data */
|
||||
class ExpandingVectorValues extends FilterVectorValues {
|
||||
|
||||
private final float[] value;
|
||||
|
||||
/**
|
||||
* @param in the wrapped values
|
||||
*/
|
||||
protected ExpandingVectorValues(VectorValues in) {
|
||||
super(in);
|
||||
value = new float[in.dimension()];
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
BytesRef binaryValue = binaryValue();
|
||||
byte[] bytes = binaryValue.bytes;
|
||||
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
|
||||
value[i] = bytes[j];
|
||||
}
|
||||
return value;
|
||||
}
|
||||
}
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
|||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -238,12 +239,31 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
|||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
|
||||
return new ExpandingVectorValues(values);
|
||||
} else {
|
||||
return values;
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -303,7 +323,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
|||
|
||||
// bound k by total number of vectors to prevent oversizing data structures
|
||||
k = Math.min(k, fieldEntry.size());
|
||||
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||
|
||||
NeighborQueue results =
|
||||
HnswGraphSearcher.search(
|
||||
|
|
|
@ -391,17 +391,21 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||
|
||||
IndexOutput tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
vectorData.getName(), "temp", segmentWriteState.context);
|
||||
IndexInput vectorDataInput = null;
|
||||
boolean success = false;
|
||||
try {
|
||||
// write the vector data to a temporary file
|
||||
// write the vector data to a temporary file
|
||||
DocsWithFieldSet docsWithField =
|
||||
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
|
||||
|
@ -417,24 +421,50 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
// we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||
// doesn't need to know docIds
|
||||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
OffHeapVectorValues offHeapVectors =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
||||
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
OnHeapHnswGraph graph = null;
|
||||
int[][] vectorIndexNodeOffsets = null;
|
||||
if (offHeapVectors.size() != 0) {
|
||||
if (docsWithField.cardinality() != 0) {
|
||||
// build graph
|
||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
offHeapVectors,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(offHeapVectors.copy());
|
||||
graph =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> {
|
||||
OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize);
|
||||
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(
|
||||
vectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||
}
|
||||
};
|
||||
vectorIndexNodeOffsets = writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
@ -605,16 +635,37 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||
* vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = byteVectorValues.binaryValue();
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(
|
||||
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
||||
IndexOutput output, VectorValues floatVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
BytesRef binaryValue = vectors.binaryValue();
|
||||
assert binaryValue.length == vectors.dimension() * scalarSize;
|
||||
BytesRef binaryValue = floatVectorValues.binaryValue();
|
||||
assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
|
||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
|
@ -631,7 +682,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<T> vectors;
|
||||
private final RAVectorValues<T> raVectorValues;
|
||||
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||
|
||||
private int lastDocID = -1;
|
||||
|
@ -657,36 +707,31 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
};
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||
throws IOException {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
vectors = new ArrayList<>();
|
||||
raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||
hnswGraphBuilder =
|
||||
(HnswGraphBuilder<T>)
|
||||
HnswGraphBuilder.create(
|
||||
raVectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
HnswGraphBuilder.create(
|
||||
new RAVectorValues<>(vectors, dim),
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(infoStream);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("unchecked")
|
||||
public void addValue(int docID, Object value) throws IOException {
|
||||
public void addValue(int docID, T vectorValue) throws IOException {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
T vectorValue = (T) value;
|
||||
assert docID > lastDocID;
|
||||
docsWithField.add(docID);
|
||||
vectors.add(copyValue(vectorValue));
|
||||
|
@ -719,7 +764,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
}
|
||||
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||
private final List<T> vectors;
|
||||
private final int dim;
|
||||
|
||||
|
@ -739,17 +784,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return (float[]) vectors.get(targetOrd);
|
||||
public T vectorValue(int targetOrd) throws IOException {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
return (BytesRef) vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<T> copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,283 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues<BytesRef> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
protected final IndexInput slice;
|
||||
protected final BytesRef binaryValue;
|
||||
protected final ByteBuffer byteBuffer;
|
||||
protected final int byteSize;
|
||||
|
||||
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = byteSize;
|
||||
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
public abstract int ordToDoc(int ord);
|
||||
|
||||
static OffHeapByteVectorValues load(
|
||||
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||
if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
int byteSize = fieldEntry.dimension;
|
||||
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||
}
|
||||
}
|
||||
|
||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
|
||||
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||
super(dimension, size, slice, byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
slice.seek((long) doc * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return ord;
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
private final DirectMonotonicReader ordToDoc;
|
||||
private final IndexedDISI disi;
|
||||
// dataIn was used to init a new IndexedDIS for #randomAccess()
|
||||
private final IndexInput dataIn;
|
||||
private final Lucene95HnswVectorsReader.FieldEntry fieldEntry;
|
||||
|
||||
public SparseOffHeapVectorValues(
|
||||
Lucene95HnswVectorsReader.FieldEntry fieldEntry,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice,
|
||||
int byteSize)
|
||||
throws IOException {
|
||||
|
||||
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||
this.fieldEntry = fieldEntry;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||
this.dataIn = dataIn;
|
||||
this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
|
||||
this.disi =
|
||||
new IndexedDISI(
|
||||
dataIn,
|
||||
fieldEntry.docsWithFieldOffset,
|
||||
fieldEntry.docsWithFieldLength,
|
||||
fieldEntry.jumpTableEntryCount,
|
||||
fieldEntry.denseRankPower,
|
||||
fieldEntry.size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
slice.seek((long) (disi.index()) * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return (int) ordToDoc.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
return new Bits() {
|
||||
@Override
|
||||
public boolean get(int index) {
|
||||
return acceptDocs.get(ordToDoc(index));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return size;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null, 0);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ package org.apache.lucene.codecs.lucene95;
|
|||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
|
@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
|||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
abstract class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -66,31 +68,17 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
return value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
readValue(targetOrd);
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private void readValue(int targetOrd) throws IOException {
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||
}
|
||||
|
||||
public abstract int ordToDoc(int ord);
|
||||
|
||||
static OffHeapVectorValues load(
|
||||
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||
if (fieldEntry.docsWithFieldOffset == -2) {
|
||||
if (fieldEntry.docsWithFieldOffset == -2
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||
int byteSize =
|
||||
switch (fieldEntry.vectorEncoding) {
|
||||
case BYTE -> fieldEntry.dimension;
|
||||
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
|
||||
};
|
||||
int byteSize = fieldEntry.dimension * Float.BYTES;
|
||||
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||
return new DenseOffHeapVectorValues(
|
||||
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||
|
@ -143,7 +131,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -219,7 +207,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -291,7 +279,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
@ -300,11 +288,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
|
@ -255,6 +256,16 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||
if (knnVectorsReader == null) {
|
||||
return null;
|
||||
} else {
|
||||
return knnVectorsReader.getByteVectorValues(field);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
|
|
@ -0,0 +1,163 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.document;
|
||||
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
|
||||
* - that is, every dimension of a vector contains an explicit value, stored packed into an array
|
||||
* (of type byte[]) whose length is the vector dimension. Values can be retrieved using {@link
|
||||
* ByteVectorValues}, which is a forward-only docID-based iterator and also offers random-access by
|
||||
* dense ordinal (not docId). {@link VectorSimilarityFunction} may be used to compare vectors at
|
||||
* query time (for example as part of result ranking). A KnnByteVectorField may be associated with a
|
||||
* search similarity function defining the metric used for nearest-neighbor search among vectors of
|
||||
* that field.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class KnnByteVectorField extends Field {
|
||||
|
||||
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
||||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
int dimension = v.length;
|
||||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
|
||||
throw new IllegalArgumentException(
|
||||
"cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
|
||||
}
|
||||
if (similarityFunction == null) {
|
||||
throw new IllegalArgumentException("similarity function must not be null");
|
||||
}
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new vector query for the provided field targeting the byte vector
|
||||
*
|
||||
* @param field The field to query
|
||||
* @param queryVector The byte vector target
|
||||
* @param k The number of nearest neighbors to gather
|
||||
* @return A new vector query
|
||||
*/
|
||||
public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
|
||||
return new KnnByteVectorQuery(field, queryVector, k);
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method for creating a vector field type.
|
||||
*
|
||||
* @param dimension dimension of vectors
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
||||
*/
|
||||
public static FieldType createFieldType(
|
||||
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||
* be constant-length.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnByteVectorField(
|
||||
String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||
super(name, createType(vector, similarityFunction));
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||
* the same dimension and similarity function.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnByteVectorField(String name, BytesRef vector) {
|
||||
this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param fieldType field type
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to create a vector for field "
|
||||
+ name
|
||||
+ " using byte[] but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/** Return the vector value of this field */
|
||||
public BytesRef vectorValue() {
|
||||
return (BytesRef) fieldsData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the vector value of this field
|
||||
*
|
||||
* @param value the value to set; must not be null, and length must match the field type
|
||||
*/
|
||||
public void setVectorValue(BytesRef value) {
|
||||
if (value == null) {
|
||||
throw new IllegalArgumentException("value must not be null");
|
||||
}
|
||||
if (value.length != type.vectorDimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"value length " + value.length + " must match field dimension " + type.vectorDimension());
|
||||
}
|
||||
fieldsData = value;
|
||||
}
|
||||
}
|
|
@ -20,7 +20,8 @@ package org.apache.lucene.document;
|
|||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
|
@ -41,18 +42,7 @@ public class KnnVectorField extends Field {
|
|||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
|
||||
}
|
||||
|
||||
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
||||
if (v == null) {
|
||||
throw new IllegalArgumentException("vector value must not be null");
|
||||
}
|
||||
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
|
||||
}
|
||||
|
||||
private static FieldType createType(
|
||||
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
||||
int dimension = v.length;
|
||||
if (dimension == 0) {
|
||||
throw new IllegalArgumentException("cannot index an empty vector");
|
||||
}
|
||||
|
@ -64,13 +54,13 @@ public class KnnVectorField extends Field {
|
|||
throw new IllegalArgumentException("similarity function must not be null");
|
||||
}
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
||||
type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method for creating a vector field type with the default FLOAT32 encoding.
|
||||
* A convenience method for creating a vector field type.
|
||||
*
|
||||
* @param dimension dimension of vectors
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
|
@ -78,23 +68,22 @@ public class KnnVectorField extends Field {
|
|||
*/
|
||||
public static FieldType createFieldType(
|
||||
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* A convenience method for creating a vector field type.
|
||||
* Create a new vector query for the provided field targeting the float vector
|
||||
*
|
||||
* @param dimension dimension of vectors
|
||||
* @param vectorEncoding the encoding of the scalar values
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
||||
* @param field The field to query
|
||||
* @param queryVector The float vector target
|
||||
* @param k The number of nearest neighbors to gather
|
||||
* @return A new vector query
|
||||
*/
|
||||
public static FieldType createFieldType(
|
||||
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
||||
FieldType type = new FieldType();
|
||||
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
||||
type.freeze();
|
||||
return type;
|
||||
public static Query newVectorQuery(String field, float[] queryVector, int k) {
|
||||
return new KnnVectorQuery(field, queryVector, k);
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -114,23 +103,6 @@ public class KnnVectorField extends Field {
|
|||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||
* be constant-length.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param similarityFunction a function defining vector proximity.
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||
super(name, createType(vector, similarityFunction));
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||
|
@ -167,28 +139,6 @@ public class KnnVectorField extends Field {
|
|||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
* @param fieldType field type
|
||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||
* dimension > 1024.
|
||||
*/
|
||||
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||
super(name, fieldType);
|
||||
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"Attempt to create a vector for field "
|
||||
+ name
|
||||
+ " using BytesRef but the field encoding is "
|
||||
+ fieldType.vectorEncoding());
|
||||
}
|
||||
fieldsData = vector;
|
||||
}
|
||||
|
||||
/** Return the vector value of this field */
|
||||
public float[] vectorValue() {
|
||||
return (float[]) fieldsData;
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* This class provides access to per-document floating point vector values indexed as {@link
|
||||
* KnnByteVectorField}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract class ByteVectorValues extends DocIdSetIterator {
|
||||
|
||||
/** The maximum length of a vector */
|
||||
public static final int MAX_DIMENSIONS = 1024;
|
||||
|
||||
/** Sole constructor */
|
||||
protected ByteVectorValues() {}
|
||||
|
||||
/** Return the dimension of the vectors */
|
||||
public abstract int dimension();
|
||||
|
||||
/**
|
||||
* Return the number of vectors for this field.
|
||||
*
|
||||
* @return the number of vectors returned by this iterator
|
||||
*/
|
||||
public abstract int size();
|
||||
|
||||
@Override
|
||||
public final long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
/**
|
||||
* Return the vector value for the current document ID. It is illegal to call this method when the
|
||||
* iterator is not positioned: before advancing, or after failing to advance. The returned array
|
||||
* may be shared across calls, re-used, and modified as the iterator advances.
|
||||
*
|
||||
* @return the vector value
|
||||
*/
|
||||
public abstract BytesRef vectorValue() throws IOException;
|
||||
|
||||
/**
|
||||
* Return the binary encoded vector value for the current document ID. These are the bytes
|
||||
* corresponding to the float array return by {@link #vectorValue}. It is illegal to call this
|
||||
* method when the iterator is not positioned: before advancing, or after failing to advance. The
|
||||
* returned storage may be shared across calls, re-used and modified as the iterator advances.
|
||||
*
|
||||
* @return the binary value
|
||||
*/
|
||||
public final BytesRef binaryValue() throws IOException {
|
||||
return vectorValue();
|
||||
}
|
||||
}
|
|
@ -34,6 +34,7 @@ import java.util.Iterator;
|
|||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.concurrent.Callable;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
|
@ -2588,62 +2589,37 @@ public final class CheckIndex implements Closeable {
|
|||
+ "\" has vector values but dimension is "
|
||||
+ dimension);
|
||||
}
|
||||
VectorValues values = reader.getVectorValues(fieldInfo.name);
|
||||
if (values == null) {
|
||||
if (reader.getVectorValues(fieldInfo.name) == null
|
||||
&& reader.getByteVectorValues(fieldInfo.name) == null) {
|
||||
continue;
|
||||
}
|
||||
|
||||
status.totalKnnVectorFields++;
|
||||
|
||||
int docCount = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
// search the first maxNumSearches vectors to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
TopDocs docs =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> reader
|
||||
.getVectorReader()
|
||||
.search(
|
||||
fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
||||
case BYTE -> reader
|
||||
.getVectorReader()
|
||||
.search(
|
||||
fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
|
||||
};
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
}
|
||||
}
|
||||
float[] vectorValue = values.vectorValue();
|
||||
int valueLength = vectorValue.length;
|
||||
if (valueLength != dimension) {
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE:
|
||||
checkByteVectorValues(
|
||||
Objects.requireNonNull(reader.getByteVectorValues(fieldInfo.name)),
|
||||
fieldInfo,
|
||||
status,
|
||||
reader);
|
||||
break;
|
||||
case FLOAT32:
|
||||
checkFloatVectorValues(
|
||||
Objects.requireNonNull(reader.getVectorValues(fieldInfo.name)),
|
||||
fieldInfo,
|
||||
status,
|
||||
reader);
|
||||
break;
|
||||
default:
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has a value whose dimension="
|
||||
+ valueLength
|
||||
+ " not matching the field's dimension="
|
||||
+ dimension);
|
||||
}
|
||||
++docCount;
|
||||
+ "\" has unexpected vector encoding: "
|
||||
+ fieldInfo.getVectorEncoding());
|
||||
}
|
||||
if (docCount != values.size()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has size="
|
||||
+ values.size()
|
||||
+ " but when iterated, returns "
|
||||
+ docCount
|
||||
+ " docs with values");
|
||||
}
|
||||
status.totalVectorValues += docCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
msg(
|
||||
infoStream,
|
||||
String.format(
|
||||
|
@ -2667,6 +2643,96 @@ public final class CheckIndex implements Closeable {
|
|||
return status;
|
||||
}
|
||||
|
||||
private static void checkFloatVectorValues(
|
||||
VectorValues values,
|
||||
FieldInfo fieldInfo,
|
||||
CheckIndex.Status.VectorValuesStatus status,
|
||||
CodecReader codecReader)
|
||||
throws IOException {
|
||||
int docCount = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
// search the first maxNumSearches vectors to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
TopDocs docs =
|
||||
codecReader
|
||||
.getVectorReader()
|
||||
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
}
|
||||
}
|
||||
int valueLength = values.vectorValue().length;
|
||||
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has a value whose dimension="
|
||||
+ valueLength
|
||||
+ " not matching the field's dimension="
|
||||
+ fieldInfo.getVectorDimension());
|
||||
}
|
||||
++docCount;
|
||||
}
|
||||
if (docCount != values.size()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has size="
|
||||
+ values.size()
|
||||
+ " but when iterated, returns "
|
||||
+ docCount
|
||||
+ " docs with values");
|
||||
}
|
||||
status.totalVectorValues += docCount;
|
||||
}
|
||||
|
||||
private static void checkByteVectorValues(
|
||||
ByteVectorValues values,
|
||||
FieldInfo fieldInfo,
|
||||
CheckIndex.Status.VectorValuesStatus status,
|
||||
CodecReader codecReader)
|
||||
throws IOException {
|
||||
int docCount = 0;
|
||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||
// search the first maxNumSearches vectors to exercise the graph
|
||||
if (values.docID() % everyNdoc == 0) {
|
||||
TopDocs docs =
|
||||
codecReader
|
||||
.getVectorReader()
|
||||
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
||||
if (docs.scoreDocs.length == 0) {
|
||||
throw new CheckIndexException(
|
||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||
}
|
||||
}
|
||||
int valueLength = values.vectorValue().length;
|
||||
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has a value whose dimension="
|
||||
+ valueLength
|
||||
+ " not matching the field's dimension="
|
||||
+ fieldInfo.getVectorDimension());
|
||||
}
|
||||
++docCount;
|
||||
}
|
||||
if (docCount != values.size()) {
|
||||
throw new CheckIndexException(
|
||||
"Field \""
|
||||
+ fieldInfo.name
|
||||
+ "\" has size="
|
||||
+ values.size()
|
||||
+ " but when iterated, returns "
|
||||
+ docCount
|
||||
+ " docs with values");
|
||||
}
|
||||
status.totalVectorValues += docCount;
|
||||
}
|
||||
|
||||
/**
|
||||
* Walks the entire N-dimensional points space, verifying that all points fall within the last
|
||||
* cell's boundaries.
|
||||
|
|
|
@ -218,7 +218,9 @@ public abstract class CodecReader extends LeafReader {
|
|||
public final VectorValues getVectorValues(String field) throws IOException {
|
||||
ensureOpen();
|
||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorDimension() == 0) {
|
||||
if (fi == null
|
||||
|| fi.getVectorDimension() == 0
|
||||
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
// Field does not exist or does not index vectors
|
||||
return null;
|
||||
}
|
||||
|
@ -226,6 +228,20 @@ public abstract class CodecReader extends LeafReader {
|
|||
return getVectorReader().getVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
ensureOpen();
|
||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||
if (fi == null
|
||||
|| fi.getVectorDimension() == 0
|
||||
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||
// Field does not exist or does not index vectors
|
||||
return null;
|
||||
}
|
||||
|
||||
return getVectorReader().getByteVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
|
|
|
@ -53,6 +53,11 @@ abstract class DocValuesLeafReader extends LeafReader {
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
|
|
|
@ -323,6 +323,15 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return new ExitableVectorValues(vectorValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
final ByteVectorValues vectorValues = in.getByteVectorValues(field);
|
||||
if (vectorValues == null) {
|
||||
return null;
|
||||
}
|
||||
return new ExitableByteVectorValues(vectorValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
|
@ -387,17 +396,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
}
|
||||
}
|
||||
|
||||
private class ExitableVectorValues extends FilterVectorValues {
|
||||
private class ExitableVectorValues extends VectorValues {
|
||||
private int docToCheck;
|
||||
private final VectorValues vectorValues;
|
||||
|
||||
public ExitableVectorValues(VectorValues vectorValues) {
|
||||
super(vectorValues);
|
||||
this.vectorValues = vectorValues;
|
||||
docToCheck = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
final int advance = super.advance(target);
|
||||
final int advance = vectorValues.advance(target);
|
||||
if (advance >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
|
@ -405,9 +415,14 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return advance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return vectorValues.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
final int nextDoc = super.nextDoc();
|
||||
final int nextDoc = vectorValues.nextDoc();
|
||||
if (nextDoc >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
|
@ -415,14 +430,91 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
|||
return nextDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return vectorValues.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return in.vectorValue();
|
||||
return vectorValues.vectorValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectorValues.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
return in.binaryValue();
|
||||
return vectorValues.binaryValue();
|
||||
}
|
||||
|
||||
/**
|
||||
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||
* if {@link Thread#interrupted()} returns true.
|
||||
*/
|
||||
private void checkAndThrow() {
|
||||
if (queryTimeout.shouldExit()) {
|
||||
throw new ExitingReaderException(
|
||||
"The request took too long to iterate over vector values. Timeout: "
|
||||
+ queryTimeout.toString()
|
||||
+ ", VectorValues="
|
||||
+ in);
|
||||
} else if (Thread.interrupted()) {
|
||||
throw new ExitingReaderException(
|
||||
"Interrupted while iterating over vector values. VectorValues=" + in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class ExitableByteVectorValues extends ByteVectorValues {
|
||||
private int docToCheck;
|
||||
private final ByteVectorValues vectorValues;
|
||||
|
||||
public ExitableByteVectorValues(ByteVectorValues vectorValues) {
|
||||
this.vectorValues = vectorValues;
|
||||
docToCheck = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
final int advance = vectorValues.advance(target);
|
||||
if (advance >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return advance;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return vectorValues.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
final int nextDoc = vectorValues.nextDoc();
|
||||
if (nextDoc >= docToCheck) {
|
||||
checkAndThrow();
|
||||
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||
}
|
||||
return nextDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return vectorValues.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectorValues.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
return vectorValues.vectorValue();
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -351,6 +351,11 @@ public abstract class FilterLeafReader extends LeafReader {
|
|||
return in.getVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
return in.getByteVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.PointsFormat;
|
||||
import org.apache.lucene.codecs.PointsWriter;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.Sort;
|
||||
|
@ -721,11 +722,7 @@ final class IndexingChain implements Accountable {
|
|||
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
||||
}
|
||||
if (fieldType.vectorDimension() != 0) {
|
||||
switch (fieldType.vectorEncoding()) {
|
||||
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
|
||||
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
|
||||
docID, ((KnnVectorField) field).vectorValue());
|
||||
}
|
||||
indexVectorValue(docID, pf, fieldType.vectorEncoding(), field);
|
||||
}
|
||||
return indexedField;
|
||||
}
|
||||
|
@ -959,6 +956,18 @@ final class IndexingChain implements Accountable {
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void indexVectorValue(
|
||||
int docID, PerField pf, VectorEncoding vectorEncoding, IndexableField field)
|
||||
throws IOException {
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
|
||||
.addValue(docID, ((KnnByteVectorField) field).vectorValue());
|
||||
case FLOAT32 -> ((KnnFieldVectorsWriter<float[]>) pf.knnFieldVectorsWriter)
|
||||
.addValue(docID, ((KnnVectorField) field).vectorValue());
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */
|
||||
private PerField getPerField(String name) {
|
||||
final int hashPos = name.hashCode() & hashMask;
|
||||
|
|
|
@ -208,6 +208,14 @@ public abstract class LeafReader extends IndexReader {
|
|||
*/
|
||||
public abstract VectorValues getVectorValues(String field) throws IOException;
|
||||
|
||||
/**
|
||||
* Returns {@link ByteVectorValues} for this field, or null if no {@link ByteVectorValues} were
|
||||
* indexed. The returned instance should only be used by a single thread.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||
|
||||
/**
|
||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||
|
|
|
@ -408,6 +408,13 @@ public class ParallelLeafReader extends LeafReader {
|
|||
return reader == null ? null : reader.getVectorValues(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
|
||||
ensureOpen();
|
||||
LeafReader reader = fieldToReader.get(fieldName);
|
||||
return reader == null ? null : reader.getByteVectorValues(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
|
|
|
@ -168,6 +168,11 @@ public final class SlowCodecReaderWrapper {
|
|||
return reader.getVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
return reader.getByteVectorValues(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
|
|
@ -222,34 +222,21 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
final FixedBitSet docsWithField;
|
||||
final float[][] vectors;
|
||||
final ByteBuffer vectorAsBytes;
|
||||
final BytesRef[] binaryVectors;
|
||||
|
||||
private int docId = -1;
|
||||
|
||||
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap, VectorEncoding encoding)
|
||||
throws IOException {
|
||||
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.size = delegate.size();
|
||||
this.dimension = delegate.dimension();
|
||||
docsWithField = new FixedBitSet(sortMap.size());
|
||||
if (encoding == VectorEncoding.BYTE) {
|
||||
vectors = null;
|
||||
binaryVectors = new BytesRef[sortMap.size()];
|
||||
vectorAsBytes = null;
|
||||
} else {
|
||||
vectors = new float[sortMap.size()][];
|
||||
binaryVectors = null;
|
||||
vectorAsBytes =
|
||||
ByteBuffer.allocate(delegate.dimension() * encoding.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
}
|
||||
vectors = new float[sortMap.size()][];
|
||||
vectorAsBytes =
|
||||
ByteBuffer.allocate(delegate.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(doc);
|
||||
docsWithField.set(newDocID);
|
||||
if (encoding == VectorEncoding.BYTE) {
|
||||
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.binaryValue());
|
||||
} else {
|
||||
vectors[newDocID] = delegate.vectorValue().clone();
|
||||
}
|
||||
vectors[newDocID] = delegate.vectorValue().clone();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -265,12 +252,8 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
|
||||
@Override
|
||||
public BytesRef binaryValue() throws IOException {
|
||||
if (binaryVectors != null) {
|
||||
return binaryVectors[docId];
|
||||
} else {
|
||||
vectorAsBytes.asFloatBuffer().put(vectors[docId]);
|
||||
return new BytesRef(vectorAsBytes.array());
|
||||
}
|
||||
vectorAsBytes.asFloatBuffer().put(vectors[docId]);
|
||||
return new BytesRef(vectorAsBytes.array());
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -297,6 +280,60 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
}
|
||||
}
|
||||
|
||||
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||
final int size;
|
||||
final int dimension;
|
||||
final FixedBitSet docsWithField;
|
||||
final BytesRef[] binaryVectors;
|
||||
|
||||
private int docId = -1;
|
||||
|
||||
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.size = delegate.size();
|
||||
this.dimension = delegate.dimension();
|
||||
docsWithField = new FixedBitSet(sortMap.size());
|
||||
binaryVectors = new BytesRef[sortMap.size()];
|
||||
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(doc);
|
||||
docsWithField.set(newDocID);
|
||||
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(docId + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() throws IOException {
|
||||
return binaryVectors[docId];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
if (target >= docsWithField.length()) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return docId = docsWithField.nextSetBit(target);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
|
||||
* . If the reader is already sorted, this method might return the reader as-is.
|
||||
|
@ -465,9 +502,12 @@ public final class SortingCodecReader extends FilterCodecReader {
|
|||
|
||||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
|
||||
return new SortingVectorValues(
|
||||
delegate.getVectorValues(field), docMap, fi.getVectorEncoding());
|
||||
return new SortingVectorValues(delegate.getVectorValues(field), docMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
return new SortingByteVectorValues(delegate.getByteVectorValues(field), docMap);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -61,8 +61,8 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
|
||||
/**
|
||||
* Return the binary encoded vector value for the current document ID. These are the bytes
|
||||
* corresponding to the float array return by {@link #vectorValue}. It is illegal to call this
|
||||
* method when the iterator is not positioned: before advancing, or after failing to advance. The
|
||||
* corresponding to the array return by {@link #vectorValue}. It is illegal to call this method
|
||||
* when the iterator is not positioned: before advancing, or after failing to advance. The
|
||||
* returned storage may be shared across calls, re-used and modified as the iterator advances.
|
||||
*
|
||||
* @return the binary value
|
||||
|
|
|
@ -31,7 +31,8 @@ import org.apache.lucene.index.Terms;
|
|||
|
||||
/**
|
||||
* A {@link Query} that matches documents that contain either a {@link
|
||||
* org.apache.lucene.document.KnnVectorField}, or a field that indexes norms or doc values.
|
||||
* org.apache.lucene.document.KnnVectorField}, {@link org.apache.lucene.document.KnnByteVectorField}
|
||||
* or a field that indexes norms or doc values.
|
||||
*/
|
||||
public class FieldExistsQuery extends Query {
|
||||
private String field;
|
||||
|
@ -127,7 +128,12 @@ public class FieldExistsQuery extends Query {
|
|||
break;
|
||||
}
|
||||
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
||||
if (leaf.getVectorValues(field).size() != leaf.maxDoc()) {
|
||||
int numVectors =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> leaf.getVectorValues(field).size();
|
||||
case BYTE -> leaf.getByteVectorValues(field).size();
|
||||
};
|
||||
if (numVectors != leaf.maxDoc()) {
|
||||
allReadersRewritable = false;
|
||||
break;
|
||||
}
|
||||
|
@ -175,7 +181,11 @@ public class FieldExistsQuery extends Query {
|
|||
if (fieldInfo.hasNorms()) { // the field indexes norms
|
||||
iterator = context.reader().getNormValues(field);
|
||||
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
||||
iterator = context.reader().getVectorValues(field);
|
||||
iterator =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case FLOAT32 -> context.reader().getVectorValues(field);
|
||||
case BYTE -> context.reader().getByteVectorValues(field);
|
||||
};
|
||||
} else if (fieldInfo.getDocValuesType()
|
||||
!= DocValuesType.NONE) { // the field indexes doc values
|
||||
switch (fieldInfo.getDocValuesType()) {
|
||||
|
|
|
@ -54,7 +54,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
|||
* @param k the number of documents to find
|
||||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||
*/
|
||||
public KnnByteVectorQuery(String field, byte[] target, int k) {
|
||||
public KnnByteVectorQuery(String field, BytesRef target, int k) {
|
||||
this(field, target, k, null);
|
||||
}
|
||||
|
||||
|
@ -68,9 +68,9 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
|||
* @param filter a filter applied before the vector search
|
||||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||
*/
|
||||
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
|
||||
public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
|
||||
super(field, k, filter);
|
||||
this.target = new BytesRef(target);
|
||||
this.target = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -29,7 +30,6 @@ import org.apache.lucene.util.BytesRef;
|
|||
* search over the vectors.
|
||||
*/
|
||||
abstract class VectorScorer {
|
||||
protected final VectorValues values;
|
||||
protected final VectorSimilarityFunction similarity;
|
||||
|
||||
/**
|
||||
|
@ -48,53 +48,72 @@ abstract class VectorScorer {
|
|||
|
||||
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
|
||||
throws IOException {
|
||||
VectorValues values = context.reader().getVectorValues(fi.name);
|
||||
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
|
||||
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
||||
return new ByteVectorScorer(values, query, similarity);
|
||||
}
|
||||
|
||||
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
|
||||
this.values = values;
|
||||
VectorScorer(VectorSimilarityFunction similarity) {
|
||||
this.similarity = similarity;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
/** Compute the similarity score for the current document. */
|
||||
abstract float score() throws IOException;
|
||||
|
||||
abstract boolean advanceExact(int doc) throws IOException;
|
||||
|
||||
private static class ByteVectorScorer extends VectorScorer {
|
||||
private final BytesRef query;
|
||||
private final ByteVectorValues values;
|
||||
|
||||
protected ByteVectorScorer(
|
||||
VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
|
||||
super(values, similarity);
|
||||
ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
|
||||
super(similarity);
|
||||
this.values = values;
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
return similarity.compare(query, values.binaryValue());
|
||||
return similarity.compare(query, values.vectorValue());
|
||||
}
|
||||
}
|
||||
|
||||
private static class FloatVectorScorer extends VectorScorer {
|
||||
private final float[] query;
|
||||
private final VectorValues values;
|
||||
|
||||
protected FloatVectorScorer(
|
||||
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
||||
super(values, similarity);
|
||||
super(similarity);
|
||||
this.query = query;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
/**
|
||||
* Advance the instance to the given document ID and return true if there is a value for that
|
||||
* document.
|
||||
*/
|
||||
@Override
|
||||
public boolean advanceExact(int doc) throws IOException {
|
||||
int vectorDoc = values.docID();
|
||||
if (vectorDoc < doc) {
|
||||
vectorDoc = values.advance(doc);
|
||||
}
|
||||
return vectorDoc == doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -53,7 +53,7 @@ public final class HnswGraphBuilder<T> {
|
|||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final VectorEncoding vectorEncoding;
|
||||
private final RandomAccessVectorValues vectors;
|
||||
private final RandomAccessVectorValues<T> vectors;
|
||||
private final SplittableRandom random;
|
||||
private final HnswGraphSearcher<T> graphSearcher;
|
||||
|
||||
|
@ -63,10 +63,10 @@ public final class HnswGraphBuilder<T> {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private final RandomAccessVectorValues vectorsCopy;
|
||||
private final RandomAccessVectorValues<T> vectorsCopy;
|
||||
|
||||
public static HnswGraphBuilder<?> create(
|
||||
RandomAccessVectorValues vectors,
|
||||
public static <T> HnswGraphBuilder<T> create(
|
||||
RandomAccessVectorValues<T> vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int M,
|
||||
|
@ -89,7 +89,7 @@ public final class HnswGraphBuilder<T> {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
private HnswGraphBuilder(
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<T> vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int M,
|
||||
|
@ -131,7 +131,7 @@ public final class HnswGraphBuilder<T> {
|
|||
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
|
||||
* independent accessor for the vectors
|
||||
*/
|
||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
||||
public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
|
||||
if (vectorsToAdd == this.vectors) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
|
@ -143,7 +143,7 @@ public final class HnswGraphBuilder<T> {
|
|||
return hnsw;
|
||||
}
|
||||
|
||||
private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
||||
private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
|
||||
long start = System.nanoTime(), t = start;
|
||||
// start at node 1! node 0 is added implicitly, in the constructor
|
||||
for (int node = 1; node < vectorsToAdd.size(); node++) {
|
||||
|
@ -189,16 +189,8 @@ public final class HnswGraphBuilder<T> {
|
|||
}
|
||||
}
|
||||
|
||||
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
|
||||
addGraphNode(node, getValue(node, values));
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> (T) values.binaryValue(node);
|
||||
case FLOAT32 -> (T) values.vectorValue(node);
|
||||
};
|
||||
public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
|
||||
addGraphNode(node, values.vectorValue(node));
|
||||
}
|
||||
|
||||
private long printGraphBuildStatus(int node, long start, long t) {
|
||||
|
@ -281,8 +273,8 @@ public final class HnswGraphBuilder<T> {
|
|||
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
|
||||
throws IOException {
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> isDiverse(vectors.binaryValue(candidate), neighbors, score);
|
||||
case FLOAT32 -> isDiverse(vectors.vectorValue(candidate), neighbors, score);
|
||||
case BYTE -> isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
|
||||
case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -290,7 +282,8 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
similarityFunction.compare(
|
||||
candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
if (neighborSimilarity >= score) {
|
||||
return false;
|
||||
}
|
||||
|
@ -302,7 +295,8 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
similarityFunction.compare(
|
||||
candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
if (neighborSimilarity >= score) {
|
||||
return false;
|
||||
}
|
||||
|
@ -327,9 +321,10 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
int candidateNode = neighbors.node[candidateIndex];
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
|
||||
case BYTE -> isWorstNonDiverse(
|
||||
candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
|
||||
case FLOAT32 -> isWorstNonDiverse(
|
||||
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
|
||||
candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -338,7 +333,8 @@ public final class HnswGraphBuilder<T> {
|
|||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
similarityFunction.compare(
|
||||
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
// candidate node is too similar to node i given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return true;
|
||||
|
@ -352,7 +348,8 @@ public final class HnswGraphBuilder<T> {
|
|||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
similarityFunction.compare(
|
||||
candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
// candidate node is too similar to node i given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return true;
|
||||
|
|
|
@ -81,7 +81,7 @@ public class HnswGraphSearcher<T> {
|
|||
public static NeighborQueue search(
|
||||
float[] query,
|
||||
int topK,
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<float[]> vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
HnswGraph graph,
|
||||
|
@ -137,7 +137,7 @@ public class HnswGraphSearcher<T> {
|
|||
public static NeighborQueue search(
|
||||
BytesRef query,
|
||||
int topK,
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<BytesRef> vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
HnswGraph graph,
|
||||
|
@ -198,7 +198,7 @@ public class HnswGraphSearcher<T> {
|
|||
int topK,
|
||||
int level,
|
||||
final int[] eps,
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<T> vectors,
|
||||
HnswGraph graph)
|
||||
throws IOException {
|
||||
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
|
||||
|
@ -209,7 +209,7 @@ public class HnswGraphSearcher<T> {
|
|||
int topK,
|
||||
int level,
|
||||
final int[] eps,
|
||||
RandomAccessVectorValues vectors,
|
||||
RandomAccessVectorValues<T> vectors,
|
||||
HnswGraph graph,
|
||||
Bits acceptOrds,
|
||||
int visitedLimit)
|
||||
|
@ -279,11 +279,11 @@ public class HnswGraphSearcher<T> {
|
|||
return results;
|
||||
}
|
||||
|
||||
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
|
||||
private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
|
||||
return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
|
||||
} else {
|
||||
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
|
||||
return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,7 +18,6 @@
|
|||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
/**
|
||||
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
||||
|
@ -26,7 +25,7 @@ import org.apache.lucene.util.BytesRef;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface RandomAccessVectorValues {
|
||||
public interface RandomAccessVectorValues<T> {
|
||||
|
||||
/** Return the number of vector values */
|
||||
int size();
|
||||
|
@ -35,26 +34,16 @@ public interface RandomAccessVectorValues {
|
|||
int dimension();
|
||||
|
||||
/**
|
||||
* Return the vector value indexed at the given ordinal. The provided floating point array may be
|
||||
* shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}.
|
||||
* Return the vector value indexed at the given ordinal.
|
||||
*
|
||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||
*/
|
||||
float[] vectorValue(int targetOrd) throws IOException;
|
||||
|
||||
/**
|
||||
* Return the vector indexed at the given ordinal value as an array of bytes in a BytesRef; these
|
||||
* are the bytes corresponding to the float array. The provided bytes may be shared and
|
||||
* overwritten by subsequent calls to this method and {@link #vectorValue(int)}.
|
||||
*
|
||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||
*/
|
||||
BytesRef binaryValue(int targetOrd) throws IOException;
|
||||
T vectorValue(int targetOrd) throws IOException;
|
||||
|
||||
/**
|
||||
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
|
||||
* access different values at once, to avoid overwriting the underlying float vector returned by
|
||||
* {@link RandomAccessVectorValues#vectorValue}.
|
||||
*/
|
||||
RandomAccessVectorValues copy() throws IOException;
|
||||
RandomAccessVectorValues<T> copy() throws IOException;
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|||
import java.io.StringReader;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
|
@ -611,25 +612,22 @@ public class TestField extends LuceneTestCase {
|
|||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
Document doc = new Document();
|
||||
BytesRef br = newBytesRef(new byte[5]);
|
||||
Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
|
||||
Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
|
||||
float[] vector = new float[] {1, 2};
|
||||
Field field2 = new KnnVectorField("float", vector);
|
||||
expectThrows(
|
||||
IllegalArgumentException.class,
|
||||
() -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
|
||||
assertEquals(br, field.binaryValue());
|
||||
doc.add(field);
|
||||
doc.add(field2);
|
||||
w.addDocument(doc);
|
||||
try (IndexReader r = DirectoryReader.open(w)) {
|
||||
VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary");
|
||||
ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
|
||||
assertEquals(1, binary.size());
|
||||
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
assertEquals(br, binary.binaryValue());
|
||||
assertNotNull(binary.vectorValue());
|
||||
assertEquals(br, binary.vectorValue());
|
||||
assertEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||
|
||||
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
|
||||
|
|
|
@ -112,6 +112,11 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -27,12 +27,12 @@ import org.apache.lucene.util.TestVectorUtil;
|
|||
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
||||
@Override
|
||||
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
|
||||
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
|
||||
return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
|
||||
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
|
||||
return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -49,12 +49,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
|||
@Override
|
||||
Field getKnnVectorField(
|
||||
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||
return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
|
||||
return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
Field getKnnVectorField(String name, float[] vector) {
|
||||
return new KnnVectorField(
|
||||
return new KnnByteVectorField(
|
||||
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
|
||||
}
|
||||
|
||||
|
@ -80,7 +80,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
|||
|
||||
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
|
||||
|
||||
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
|
||||
public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
|
||||
super(field, target, k, filter);
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks;
|
|||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
|
@ -79,7 +80,7 @@ public class TestVectorScorer extends LuceneTestCase {
|
|||
for (int j = 0; j < v.length; j++) {
|
||||
v.bytes[j] = (byte) contents[i][j];
|
||||
}
|
||||
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
|
||||
doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
|
||||
} else {
|
||||
doc.add(new KnnVectorField(field, contents[i]));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||
|
||||
protected final int dimension;
|
||||
protected final T[] denseValues;
|
||||
protected final T[] values;
|
||||
protected final int numVectors;
|
||||
protected final BytesRef binaryValue;
|
||||
|
||||
protected int pos = -1;
|
||||
|
||||
AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
|
||||
this.dimension = dimension;
|
||||
this.values = values;
|
||||
this.denseValues = denseValues;
|
||||
// used by tests that build a graph from bytes rather than floats
|
||||
binaryValue = new BytesRef(dimension);
|
||||
binaryValue.length = dimension;
|
||||
this.numVectors = numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T vectorValue(int targetOrd) {
|
||||
return denseValues[targetOrd];
|
||||
}
|
||||
|
||||
@Override
|
||||
public abstract AbstractMockVectorValues<T> copy();
|
||||
|
||||
public abstract T vectorValue() throws IOException;
|
||||
|
||||
private boolean seek(int target) {
|
||||
if (target >= 0 && target < values.length && values[target] != null) {
|
||||
pos = target;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public int docID() {
|
||||
return pos;
|
||||
}
|
||||
|
||||
public int nextDoc() {
|
||||
return advance(pos + 1);
|
||||
}
|
||||
|
||||
public int advance(int target) {
|
||||
while (++pos < values.length) {
|
||||
if (seek(pos)) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
}
|
|
@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
|
|||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
|
||||
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
|
@ -36,21 +35,23 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
|||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.StoredFields;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
|
@ -65,19 +66,30 @@ import org.apache.lucene.util.FixedBitSet;
|
|||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.junit.Before;
|
||||
|
||||
/** Tests HNSW KNN graphs */
|
||||
public class TestHnswGraph extends LuceneTestCase {
|
||||
abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||
|
||||
VectorSimilarityFunction similarityFunction;
|
||||
VectorEncoding vectorEncoding;
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
||||
}
|
||||
abstract VectorEncoding getVectorEncoding();
|
||||
|
||||
abstract Query knnQuery(String field, T vector, int k);
|
||||
|
||||
abstract T randomVector(int dim);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
|
||||
|
||||
abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException;
|
||||
|
||||
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
|
||||
|
||||
abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
|
||||
|
||||
abstract T getTargetVector();
|
||||
|
||||
// test writing out and reading in a graph gives the expected graph
|
||||
public void testReadWrite() throws IOException {
|
||||
|
@ -86,10 +98,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
int M = random().nextInt(4) + 2;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
long seed = random().nextLong();
|
||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
|
||||
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||
AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed);
|
||||
HnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
// Recreate the graph while indexing with the same random seed and write it out
|
||||
|
@ -115,7 +128,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
indexedDoc++;
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction));
|
||||
doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
|
||||
doc.add(new StoredField("id", v2.docID()));
|
||||
iw.addDocument(doc);
|
||||
nVec++;
|
||||
|
@ -124,7 +137,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
try (IndexReader reader = DirectoryReader.open(dir)) {
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
VectorValues values = ctx.reader().getVectorValues("field");
|
||||
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
|
||||
assertEquals(dim, values.dimension());
|
||||
assertEquals(nVec, values.size());
|
||||
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
||||
|
@ -142,15 +155,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private VectorEncoding randomVectorEncoding() {
|
||||
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
||||
}
|
||||
|
||||
// test that sorted index returns the same search results are unsorted
|
||||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||
int dim = random().nextInt(10) + 3;
|
||||
int nDoc = random().nextInt(200) + 100;
|
||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||
|
||||
int M = random().nextInt(10) + 5;
|
||||
int beamWidth = random().nextInt(10) + 5;
|
||||
|
@ -190,7 +199,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
indexedDoc++;
|
||||
}
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
||||
doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
||||
doc.add(new StoredField("id", vectors.docID()));
|
||||
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
|
||||
iw.addDocument(doc);
|
||||
|
@ -206,7 +215,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
for (int i = 0; i < 10; i++) {
|
||||
// ask to explore a lot of candidates to ensure the same returned hits,
|
||||
// as graphs of 2 indices are organized differently
|
||||
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50);
|
||||
Query query = knnQuery("vector", randomVector(dim), 50);
|
||||
List<String> ids1 = new ArrayList<>();
|
||||
List<Integer> docs1 = new ArrayList<>();
|
||||
List<String> ids2 = new ArrayList<>();
|
||||
|
@ -241,7 +250,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||
void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||
|
||||
|
@ -271,32 +280,32 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
// Make sure we actually approximately find the closest k elements. Mostly this is about
|
||||
// ensuring that we have all the distance functions, comparators, priority queues and so on
|
||||
// oriented in the right directions
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testAknnDiverse() throws IOException {
|
||||
int nDoc = 100;
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder<?> builder =
|
||||
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// run some searches
|
||||
NeighborQueue nn =
|
||||
switch (vectorEncoding) {
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
getTargetByteVector(),
|
||||
(BytesRef) getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
null,
|
||||
Integer.MAX_VALUE);
|
||||
case FLOAT32 -> HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
(float[]) getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
null,
|
||||
|
@ -323,33 +332,33 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSearchWithAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// the first 10 docs must not be deleted to ensure the expected recall
|
||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||
Bits acceptOrds = createRandomAcceptOrds(10, nDoc);
|
||||
NeighborQueue nn =
|
||||
switch (vectorEncoding) {
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
getTargetByteVector(),
|
||||
(BytesRef) getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
case FLOAT32 -> HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
(float[]) getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
|
@ -367,39 +376,39 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertTrue("sum(result docs)=" + sum, sum < 75);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||
int nDoc = 100;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// Only mark a few vectors as accepted
|
||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
||||
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
|
||||
BitSet acceptOrds = new FixedBitSet(nDoc);
|
||||
for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) {
|
||||
acceptOrds.set(i);
|
||||
}
|
||||
|
||||
// Check the search finds all accepted vectors
|
||||
int numAccepted = acceptOrds.cardinality();
|
||||
NeighborQueue nn =
|
||||
switch (vectorEncoding) {
|
||||
switch (getVectorEncoding()) {
|
||||
case FLOAT32 -> HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
(float[]) getTargetVector(),
|
||||
numAccepted,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
getTargetByteVector(),
|
||||
(BytesRef) getTargetVector(),
|
||||
numAccepted,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
|
@ -413,81 +422,37 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
private float[] getTargetVector() {
|
||||
return new float[] {1, 0};
|
||||
}
|
||||
|
||||
private BytesRef getTargetByteVector() {
|
||||
return new BytesRef(new byte[] {1, 0});
|
||||
}
|
||||
|
||||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||
int nDoc = 1000;
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
// Skip over half of the documents that are closest to the query vector
|
||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||
for (int i = 500; i < nDoc; i++) {
|
||||
acceptOrds.set(i);
|
||||
}
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
VectorEncoding.FLOAT32,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
int[] nodes = nn.nodes();
|
||||
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||
int sum = 0;
|
||||
for (int node : nodes) {
|
||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||
sum += node;
|
||||
}
|
||||
// We still expect to get reasonable recall. The lowest non-skipped docIds
|
||||
// are closest to the query vector: sum(500,509) = 5045
|
||||
assertTrue("sum(result docs)=" + sum, sum < 5100);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testVisitedLimit() throws IOException {
|
||||
int nDoc = 500;
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
||||
HnswGraphBuilder<?> builder =
|
||||
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
int topK = 50;
|
||||
int visitedLimit = topK + random().nextInt(5);
|
||||
NeighborQueue nn =
|
||||
switch (vectorEncoding) {
|
||||
switch (getVectorEncoding()) {
|
||||
case FLOAT32 -> HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
(float[]) getTargetVector(),
|
||||
topK,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
createRandomAcceptOrds(0, vectors.size),
|
||||
createRandomAcceptOrds(0, nDoc),
|
||||
visitedLimit);
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
getTargetByteVector(),
|
||||
(BytesRef) getTargetVector(),
|
||||
topK,
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
createRandomAcceptOrds(0, vectors.size),
|
||||
createRandomAcceptOrds(0, nDoc),
|
||||
visitedLimit);
|
||||
};
|
||||
|
||||
|
@ -504,8 +469,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
IllegalArgumentException.class,
|
||||
() ->
|
||||
HnswGraphBuilder.create(
|
||||
new RandomVectorValues(1, 1, random()),
|
||||
VectorEncoding.FLOAT32,
|
||||
vectorValues(1, 1),
|
||||
getVectorEncoding(),
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
0,
|
||||
10,
|
||||
|
@ -515,8 +480,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
IllegalArgumentException.class,
|
||||
() ->
|
||||
HnswGraphBuilder.create(
|
||||
new RandomVectorValues(1, 1, random()),
|
||||
VectorEncoding.FLOAT32,
|
||||
vectorValues(1, 1),
|
||||
getVectorEncoding(),
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
10,
|
||||
0,
|
||||
|
@ -530,13 +495,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
VectorSimilarityFunction similarityFunction =
|
||||
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
||||
TestHnswGraph.RandomVectorValues vectors =
|
||||
new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
|
||||
RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
|
||||
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
|
||||
vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
|
||||
long actual = ramUsed(hnsw);
|
||||
|
@ -546,7 +509,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testDiversity() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
// Some carefully checked test cases with simple 2d vectors on the unit circle:
|
||||
float[][] values = {
|
||||
|
@ -558,21 +520,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
unitVector2d(0.77),
|
||||
unitVector2d(0.6)
|
||||
};
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
for (float[] v : values) {
|
||||
for (int i = 0; i < v.length; i++) {
|
||||
v[i] *= 127;
|
||||
}
|
||||
}
|
||||
}
|
||||
MockVectorValues vectors = new MockVectorValues(values);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
// now every node has tried to attach every other node as a neighbor, but
|
||||
|
@ -609,7 +564,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testDiversityFallback() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
// Some test cases can't be exercised in two dimensions;
|
||||
// in particular if a new neighbor displaces an existing neighbor
|
||||
|
@ -622,14 +576,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
{10, 0, 0},
|
||||
{0, 4, 0}
|
||||
};
|
||||
MockVectorValues vectors = new MockVectorValues(values);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
|
@ -647,7 +601,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
|
||||
public void testDiversity3d() throws IOException {
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
|
||||
float[][] values = {
|
||||
|
@ -656,14 +609,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
{0, 0, 20},
|
||||
{0, 9, 0}
|
||||
};
|
||||
MockVectorValues vectors = new MockVectorValues(values);
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||
// First add nodes until everybody gets a full neighbor list
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
||||
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
|
@ -691,44 +644,38 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
actual);
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
public void testRandom() throws IOException {
|
||||
int size = atLeast(100);
|
||||
int dim = atLeast(10);
|
||||
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
|
||||
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
|
||||
int topK = 5;
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder<T> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
|
||||
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||
|
||||
int totalMatches = 0;
|
||||
for (int i = 0; i < 100; i++) {
|
||||
NeighborQueue actual;
|
||||
float[] query;
|
||||
BytesRef bQuery = null;
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
query = randomVector8(random(), dim);
|
||||
bQuery = toBytesRef(query);
|
||||
} else {
|
||||
query = randomVector(random(), dim);
|
||||
}
|
||||
T query = randomVector(dim);
|
||||
actual =
|
||||
switch (vectorEncoding) {
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE -> HnswGraphSearcher.search(
|
||||
bQuery,
|
||||
(BytesRef) query,
|
||||
100,
|
||||
vectors,
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<BytesRef>) vectors,
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
case FLOAT32 -> HnswGraphSearcher.search(
|
||||
query,
|
||||
(float[]) query,
|
||||
100,
|
||||
vectors,
|
||||
vectorEncoding,
|
||||
(RandomAccessVectorValues<float[]>) vectors,
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
|
@ -741,10 +688,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
NeighborQueue expected = new NeighborQueue(topK, false);
|
||||
for (int j = 0; j < size; j++) {
|
||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
|
||||
if (getVectorEncoding() == VectorEncoding.BYTE) {
|
||||
assert query instanceof BytesRef;
|
||||
expected.add(
|
||||
j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
|
||||
} else {
|
||||
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
|
||||
assert query instanceof float[];
|
||||
expected.add(
|
||||
j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
|
||||
}
|
||||
if (expected.size() > topK) {
|
||||
expected.pop();
|
||||
|
@ -778,17 +729,16 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
static class CircularVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues<float[]> {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
private final BytesRef binaryValue;
|
||||
|
||||
int doc = -1;
|
||||
|
||||
CircularVectorValues(int size) {
|
||||
this.size = size;
|
||||
value = new float[2];
|
||||
binaryValue = new BytesRef(new byte[2]);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -835,14 +785,70 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
public float[] vectorValue(int ord) {
|
||||
return unitVector2d(ord / (double) size, value);
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularByteVectorValues extends ByteVectorValues
|
||||
implements RandomAccessVectorValues<BytesRef> {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
private final BytesRef bValue;
|
||||
|
||||
int doc = -1;
|
||||
|
||||
CircularByteVectorValues(int size) {
|
||||
this.size = size;
|
||||
value = new float[2];
|
||||
bValue = new BytesRef(new byte[2]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int ord) {
|
||||
float[] vectorValue = vectorValue(ord);
|
||||
for (int i = 0; i < vectorValue.length; i++) {
|
||||
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
|
||||
public CircularByteVectorValues copy() {
|
||||
return new CircularByteVectorValues(size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return 2;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
if (target >= 0 && target < size) {
|
||||
doc = target;
|
||||
} else {
|
||||
doc = NO_MORE_DOCS;
|
||||
}
|
||||
return binaryValue;
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue(int ord) {
|
||||
unitVector2d(ord / (double) size, value);
|
||||
for (int i = 0; i < value.length; i++) {
|
||||
bValue.bytes[i] = (byte) (value[i] * 127);
|
||||
}
|
||||
return bValue;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -864,7 +870,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
return neighbors;
|
||||
}
|
||||
|
||||
private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
|
||||
void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
|
||||
throws IOException {
|
||||
int uDoc, vDoc;
|
||||
while (true) {
|
||||
uDoc = u.nextDoc();
|
||||
|
@ -873,49 +880,40 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
if (uDoc == NO_MORE_DOCS) {
|
||||
break;
|
||||
}
|
||||
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
|
||||
switch (getVectorEncoding()) {
|
||||
case BYTE:
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
((BytesRef) u.vectorValue()).bytes,
|
||||
((BytesRef) v.vectorValue()).bytes);
|
||||
break;
|
||||
case FLOAT32:
|
||||
assertArrayEquals(
|
||||
"vectors do not match for doc=" + uDoc,
|
||||
(float[]) u.vectorValue(),
|
||||
(float[]) v.vectorValue(),
|
||||
1e-4f);
|
||||
break;
|
||||
default:
|
||||
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Produces random vectors and caches them for random-access. */
|
||||
static class RandomVectorValues extends MockVectorValues {
|
||||
|
||||
RandomVectorValues(int size, int dimension, Random random) {
|
||||
super(createRandomVectors(size, dimension, null, random));
|
||||
static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
|
||||
float[][] vectors = new float[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
vectors[offset] = randomVector(random, dimension);
|
||||
}
|
||||
return vectors;
|
||||
}
|
||||
|
||||
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
||||
super(createRandomVectors(size, dimension, vectorEncoding, random));
|
||||
}
|
||||
|
||||
RandomVectorValues(RandomVectorValues other) {
|
||||
super(other.values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorValues copy() {
|
||||
return new RandomVectorValues(this);
|
||||
}
|
||||
|
||||
private static float[][] createRandomVectors(
|
||||
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
||||
float[][] vectors = new float[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
vectors[offset] = randomVector(random, dimension);
|
||||
}
|
||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||
for (float[] vector : vectors) {
|
||||
if (vector != null) {
|
||||
for (int i = 0; i < vector.length; i++) {
|
||||
vector[i] = (byte) (127 * vector[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return vectors;
|
||||
static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
|
||||
byte[][] vectors = new byte[size][];
|
||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||
vectors[offset] = randomVector8(random, dimension);
|
||||
}
|
||||
return vectors;
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -937,7 +935,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
return bits;
|
||||
}
|
||||
|
||||
private static float[] randomVector(Random random, int dim) {
|
||||
static float[] randomVector(Random random, int dim) {
|
||||
float[] vec = new float[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
vec[i] = random.nextFloat();
|
||||
|
@ -949,11 +947,12 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
return vec;
|
||||
}
|
||||
|
||||
private static float[] randomVector8(Random random, int dim) {
|
||||
static byte[] randomVector8(Random random, int dim) {
|
||||
float[] fvec = randomVector(random, dim);
|
||||
byte[] bvec = new byte[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
fvec[i] *= 127;
|
||||
bvec[i] = (byte) (fvec[i] * 127);
|
||||
}
|
||||
return fvec;
|
||||
return bvec;
|
||||
}
|
||||
}
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
|||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
|
@ -704,7 +705,11 @@ public class KnnGraphTester {
|
|||
iwc.setUseCompoundFile(false);
|
||||
// iwc.setMaxBufferedDocs(10000);
|
||||
|
||||
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
|
||||
FieldType fieldType =
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
|
||||
case FLOAT32 -> KnnVectorField.createFieldType(dim, similarityFunction);
|
||||
};
|
||||
if (quiet == false) {
|
||||
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
||||
System.out.println("creating index in " + indexPath);
|
||||
|
@ -718,7 +723,7 @@ public class KnnGraphTester {
|
|||
Document doc = new Document();
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> doc.add(
|
||||
new KnnVectorField(
|
||||
new KnnByteVectorField(
|
||||
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
|
||||
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
|
||||
}
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
|
||||
private final byte[] scratch;
|
||||
|
||||
static MockByteVectorValues fromValues(byte[][] byteValues) {
|
||||
int dimension = byteValues[0].length;
|
||||
BytesRef[] values = new BytesRef[byteValues.length];
|
||||
for (int i = 0; i < byteValues.length; i++) {
|
||||
values[i] = byteValues[i] == null ? null : new BytesRef(byteValues[i]);
|
||||
}
|
||||
BytesRef[] denseValues = new BytesRef[values.length];
|
||||
int count = 0;
|
||||
for (int i = 0; i < byteValues.length; i++) {
|
||||
if (values[i] != null) {
|
||||
denseValues[count++] = values[i];
|
||||
}
|
||||
}
|
||||
return new MockByteVectorValues(values, dimension, denseValues, count);
|
||||
}
|
||||
|
||||
MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
|
||||
super(values, dimension, denseValues, numVectors);
|
||||
scratch = new byte[dimension];
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockByteVectorValues copy() {
|
||||
return new MockByteVectorValues(
|
||||
ArrayUtil.copyOfSubArray(values, 0, values.length),
|
||||
dimension,
|
||||
ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
|
||||
numVectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef vectorValue() {
|
||||
if (LuceneTestCase.random().nextBoolean()) {
|
||||
return values[pos];
|
||||
} else {
|
||||
// Sometimes use the same scratch array repeatedly, mimicing what the codec will do.
|
||||
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
|
||||
// twice in a
|
||||
// single computation.
|
||||
System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
|
||||
return new BytesRef(scratch);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,52 +17,37 @@
|
|||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
|
||||
class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
class MockVectorValues extends AbstractMockVectorValues<float[]> {
|
||||
private final float[] scratch;
|
||||
|
||||
protected final int dimension;
|
||||
protected final float[][] denseValues;
|
||||
protected final float[][] values;
|
||||
private final int numVectors;
|
||||
private final BytesRef binaryValue;
|
||||
|
||||
private int pos = -1;
|
||||
|
||||
MockVectorValues(float[][] values) {
|
||||
this.dimension = values[0].length;
|
||||
this.values = values;
|
||||
static MockVectorValues fromValues(float[][] values) {
|
||||
int dimension = values[0].length;
|
||||
int maxDoc = values.length;
|
||||
denseValues = new float[maxDoc][];
|
||||
float[][] denseValues = new float[maxDoc][];
|
||||
int count = 0;
|
||||
for (int i = 0; i < maxDoc; i++) {
|
||||
if (values[i] != null) {
|
||||
denseValues[count++] = values[i];
|
||||
}
|
||||
}
|
||||
numVectors = count;
|
||||
scratch = new float[dimension];
|
||||
// used by tests that build a graph from bytes rather than floats
|
||||
binaryValue = new BytesRef(dimension);
|
||||
binaryValue.length = dimension;
|
||||
return new MockVectorValues(values, dimension, denseValues, count);
|
||||
}
|
||||
|
||||
MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
|
||||
super(values, dimension, denseValues, numVectors);
|
||||
this.scratch = new float[dimension];
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockVectorValues copy() {
|
||||
return new MockVectorValues(values);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return numVectors;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
return new MockVectorValues(
|
||||
ArrayUtil.copyOfSubArray(values, 0, values.length),
|
||||
dimension,
|
||||
ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
|
||||
numVectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -83,42 +68,4 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
|
|||
public float[] vectorValue(int targetOrd) {
|
||||
return denseValues[targetOrd];
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
float[] value = vectorValue(targetOrd);
|
||||
for (int i = 0; i < value.length; i++) {
|
||||
binaryValue.bytes[i] = (byte) value[i];
|
||||
}
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
private boolean seek(int target) {
|
||||
if (target >= 0 && target < values.length && values[target] != null) {
|
||||
pos = target;
|
||||
return true;
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return pos;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
return advance(pos + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
while (++pos < values.length) {
|
||||
if (seek(pos)) {
|
||||
return pos;
|
||||
}
|
||||
}
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,118 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
import org.junit.Before;
|
||||
|
||||
/** Tests HNSW KNN graphs */
|
||||
public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
}
|
||||
|
||||
@Override
|
||||
VectorEncoding getVectorEncoding() {
|
||||
return VectorEncoding.BYTE;
|
||||
}
|
||||
|
||||
@Override
|
||||
Query knnQuery(String field, BytesRef vector, int k) {
|
||||
return new KnnByteVectorQuery(field, vector, k);
|
||||
}
|
||||
|
||||
@Override
|
||||
BytesRef randomVector(int dim) {
|
||||
return new BytesRef(randomVector8(random(), dim));
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
|
||||
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
|
||||
}
|
||||
|
||||
static boolean fitsInByte(float v) {
|
||||
return v <= 127 && v >= -128 && v % 1 == 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
|
||||
byte[][] bValues = new byte[values.length][];
|
||||
// The case when all floats fit within a byte already.
|
||||
boolean scaleSimple = fitsInByte(values[0][0]);
|
||||
for (int i = 0; i < values.length; i++) {
|
||||
bValues[i] = new byte[values[i].length];
|
||||
for (int j = 0; j < values[i].length; j++) {
|
||||
final float v;
|
||||
if (scaleSimple) {
|
||||
assert fitsInByte(values[i][j]);
|
||||
v = values[i][j];
|
||||
} else {
|
||||
v = values[i][j] * 127;
|
||||
}
|
||||
bValues[i][j] = (byte) v;
|
||||
}
|
||||
}
|
||||
return MockByteVectorValues.fromValues(bValues);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException {
|
||||
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
|
||||
byte[][] vectors = new byte[reader.maxDoc()][];
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
vectors[vectorValues.docID()] =
|
||||
ArrayUtil.copyOfSubArray(
|
||||
vectorValues.vectorValue().bytes,
|
||||
vectorValues.vectorValue().offset,
|
||||
vectorValues.vectorValue().offset + vectorValues.vectorValue().length);
|
||||
}
|
||||
return MockByteVectorValues.fromValues(vectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||
return new KnnByteVectorField(name, vector, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
|
||||
return new CircularByteVectorValues(nDoc);
|
||||
}
|
||||
|
||||
@Override
|
||||
BytesRef getTargetVector() {
|
||||
return new BytesRef(new byte[] {1, 0});
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.KnnVectorQuery;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
import org.junit.Before;
|
||||
|
||||
/** Tests HNSW KNN graphs */
|
||||
public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||
}
|
||||
|
||||
@Override
|
||||
VectorEncoding getVectorEncoding() {
|
||||
return VectorEncoding.FLOAT32;
|
||||
}
|
||||
|
||||
@Override
|
||||
Query knnQuery(String field, float[] vector, int k) {
|
||||
return new KnnVectorQuery(field, vector, k);
|
||||
}
|
||||
|
||||
@Override
|
||||
float[] randomVector(int dim) {
|
||||
return randomVector(random(), dim);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
|
||||
return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
|
||||
return MockVectorValues.fromValues(values);
|
||||
}
|
||||
|
||||
@Override
|
||||
AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
|
||||
throws IOException {
|
||||
VectorValues vectorValues = reader.getVectorValues(fieldName);
|
||||
float[][] vectors = new float[reader.maxDoc()][];
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
vectors[vectorValues.docID()] =
|
||||
ArrayUtil.copyOfSubArray(
|
||||
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
|
||||
}
|
||||
return MockVectorValues.fromValues(vectors);
|
||||
}
|
||||
|
||||
@Override
|
||||
Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||
return new KnnVectorField(name, vector, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
RandomAccessVectorValues<float[]> circularVectorValues(int nDoc) {
|
||||
return new CircularVectorValues(nDoc);
|
||||
}
|
||||
|
||||
@Override
|
||||
float[] getTargetVector() {
|
||||
return new float[] {1f, 0f};
|
||||
}
|
||||
|
||||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||
int nDoc = 1000;
|
||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||
RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
|
||||
HnswGraphBuilder<float[]> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
// Skip over half of the documents that are closest to the query vector
|
||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||
for (int i = 500; i < nDoc; i++) {
|
||||
acceptOrds.set(i);
|
||||
}
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
10,
|
||||
vectors.copy(),
|
||||
getVectorEncoding(),
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
acceptOrds,
|
||||
Integer.MAX_VALUE);
|
||||
|
||||
int[] nodes = nn.nodes();
|
||||
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||
int sum = 0;
|
||||
for (int node : nodes) {
|
||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||
sum += node;
|
||||
}
|
||||
// We still expect to get reasonable recall. The lowest non-skipped docIds
|
||||
// are closest to the query vector: sum(500,509) = 5045
|
||||
assertTrue("sum(result docs)=" + sum, sum < 5100);
|
||||
}
|
||||
}
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
|||
import java.util.Collections;
|
||||
import java.util.Iterator;
|
||||
import org.apache.lucene.index.BinaryDocValues;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocValuesType;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
|
@ -165,6 +166,11 @@ public class TermVectorLeafReader extends LeafReader {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String fieldName) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
|
|
@ -1395,6 +1395,11 @@ public class MemoryIndex {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String fieldName) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
|
@ -113,7 +114,9 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
@Override
|
||||
public VectorValues getVectorValues(String field) throws IOException {
|
||||
FieldInfo fi = fis.fieldInfo(field);
|
||||
assert fi != null && fi.getVectorDimension() > 0;
|
||||
assert fi != null
|
||||
&& fi.getVectorDimension() > 0
|
||||
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
|
||||
VectorValues values = delegate.getVectorValues(field);
|
||||
assert values != null;
|
||||
assert values.docID() == -1;
|
||||
|
@ -122,6 +125,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
|||
return values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldInfo fi = fis.fieldInfo(field);
|
||||
assert fi != null
|
||||
&& fi.getVectorDimension() > 0
|
||||
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
|
||||
ByteVectorValues values = delegate.getByteVectorValues(field);
|
||||
assert values != null;
|
||||
assert values.docID() == -1;
|
||||
assert values.size() >= 0;
|
||||
assert values.dimension() > 0;
|
||||
return values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||
throws IOException {
|
||||
|
|
|
@ -28,10 +28,12 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.FieldType;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.document.KnnVectorField;
|
||||
import org.apache.lucene.document.NumericDocValuesField;
|
||||
import org.apache.lucene.document.StoredField;
|
||||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CheckIndex;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
|
@ -79,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
@Override
|
||||
protected void addRandomFields(Document doc) {
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
|
||||
case BYTE -> doc.add(
|
||||
new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
|
||||
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
|
||||
}
|
||||
}
|
||||
|
@ -628,9 +631,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
if (random().nextInt(100) == 17) {
|
||||
switch (fieldVectorEncodings[field]) {
|
||||
case BYTE -> {
|
||||
BytesRef b = randomVector8(fieldDims[field]);
|
||||
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
|
||||
fieldTotals[field] += b.bytes[b.offset];
|
||||
byte[] b = randomVector8(fieldDims[field]);
|
||||
doc.add(
|
||||
new KnnByteVectorField(
|
||||
fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
|
||||
fieldTotals[field] += b[0];
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
float[] v = randomVector(fieldDims[field]);
|
||||
|
@ -648,12 +653,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
int docCount = 0;
|
||||
double checksum = 0;
|
||||
String fieldName = "int" + field;
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
VectorValues vectors = ctx.reader().getVectorValues(fieldName);
|
||||
if (vectors != null) {
|
||||
docCount += vectors.size();
|
||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectors.vectorValue()[0];
|
||||
switch (fieldVectorEncodings[field]) {
|
||||
case BYTE -> {
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||
if (byteVectorValues != null) {
|
||||
docCount += byteVectorValues.size();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
|
||||
if (vectorValues != null) {
|
||||
docCount += vectorValues.size();
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectorValues.vectorValue()[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -755,15 +775,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
LeafReader leaf = getOnlyLeafReader(reader);
|
||||
|
||||
StoredFields storedFields = leaf.storedFields();
|
||||
VectorValues vectorValues = leaf.getVectorValues(fieldName);
|
||||
ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
|
||||
assertEquals(2, vectorValues.dimension());
|
||||
assertEquals(3, vectorValues.size());
|
||||
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(-1f, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(-1, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(1, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
||||
assertEquals(0, vectorValues.vectorValue().bytes[0], 0);
|
||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||
}
|
||||
}
|
||||
|
@ -915,7 +935,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
for (int i = 0; i < numDoc; i++) {
|
||||
if (random().nextInt(7) != 3) {
|
||||
// usually index a vector value for a doc
|
||||
values[i] = randomVector8(dimension);
|
||||
values[i] = new BytesRef(randomVector8(dimension));
|
||||
++numValues;
|
||||
}
|
||||
if (random().nextBoolean() && values[i] != null) {
|
||||
|
@ -943,7 +963,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||
int valueCount = 0, totalSize = 0;
|
||||
for (LeafReaderContext ctx : reader.leaves()) {
|
||||
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
|
||||
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||
if (vectorValues == null) {
|
||||
continue;
|
||||
}
|
||||
|
@ -951,7 +971,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
int docId;
|
||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
BytesRef v = vectorValues.binaryValue();
|
||||
BytesRef v = vectorValues.vectorValue();
|
||||
assertEquals(dimension, v.length);
|
||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||
int id = Integer.parseInt(idString);
|
||||
|
@ -1141,7 +1161,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
throws IOException {
|
||||
Document doc = new Document();
|
||||
if (vector != null) {
|
||||
doc.add(new KnnVectorField(field, vector, similarityFunction));
|
||||
doc.add(new KnnByteVectorField(field, vector, similarityFunction));
|
||||
}
|
||||
doc.add(new NumericDocValuesField("sortkey", sortKey));
|
||||
String idString = Integer.toString(id);
|
||||
|
@ -1183,13 +1203,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
return v;
|
||||
}
|
||||
|
||||
private BytesRef randomVector8(int dim) {
|
||||
private byte[] randomVector8(int dim) {
|
||||
float[] v = randomVector(dim);
|
||||
byte[] b = new byte[dim];
|
||||
for (int i = 0; i < dim; i++) {
|
||||
b[i] = (byte) (v[i] * 127);
|
||||
}
|
||||
return new BytesRef(b);
|
||||
return b;
|
||||
}
|
||||
|
||||
public void testCheckIndexIncludesVectors() throws Exception {
|
||||
|
@ -1297,9 +1317,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
if (random().nextInt(4) == 3) {
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> {
|
||||
BytesRef b = randomVector8(dim);
|
||||
fieldValuesCheckSum += b.bytes[b.offset];
|
||||
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
|
||||
byte[] b = randomVector8(dim);
|
||||
fieldValuesCheckSum += b[0];
|
||||
doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
float[] v = randomVector(dim);
|
||||
|
@ -1321,15 +1341,33 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
double checksum = 0;
|
||||
int docCount = 0;
|
||||
long sumDocIds = 0;
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
VectorValues vectors = ctx.reader().getVectorValues("knn_vector");
|
||||
if (vectors != null) {
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
docCount += vectors.size();
|
||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectors.vectorValue()[0];
|
||||
Document doc = storedFields.document(vectors.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
switch (vectorEncoding) {
|
||||
case BYTE -> {
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues("knn_vector");
|
||||
if (byteVectorValues != null) {
|
||||
docCount += byteVectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case FLOAT32 -> {
|
||||
for (LeafReaderContext ctx : r.leaves()) {
|
||||
VectorValues vectorValues = ctx.reader().getVectorValues("knn_vector");
|
||||
if (vectorValues != null) {
|
||||
docCount += vectorValues.size();
|
||||
StoredFields storedFields = ctx.reader().storedFields();
|
||||
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
checksum += vectorValues.vectorValue()[0];
|
||||
Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
|
||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
|||
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||
import org.apache.lucene.codecs.TermVectorsReader;
|
||||
import org.apache.lucene.index.BinaryDocValues;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DocValuesType;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
|
@ -222,6 +223,11 @@ class MergeReaderWrapper extends LeafReader {
|
|||
return in.getVectorValues(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
|
||||
return in.getByteVectorValues(fieldName);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||
|
|
|
@ -24,6 +24,7 @@ import java.io.IOException;
|
|||
import java.util.List;
|
||||
import java.util.Random;
|
||||
import org.apache.lucene.index.BinaryDocValues;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.LeafMetaData;
|
||||
|
@ -229,6 +230,11 @@ public class QueryUtils {
|
|||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs searchNearestVectors(
|
||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||
|
|
Loading…
Reference in New Issue