Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064)

This commit is contained in:
Benjamin Trent 2023-01-11 03:20:47 -05:00 committed by GitHub
parent e14327288e
commit cc29102a24
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
63 changed files with 2656 additions and 898 deletions

View File

@ -48,7 +48,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch; private final Lucene90NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues; private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random; private final SplittableRandom random;
private final Lucene90BoundsChecker bound; private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw; final Lucene90OnHeapHnswGraph hnsw;
@ -57,7 +57,7 @@ public final class Lucene90HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without // we need two sources of vectors in order to perform diversity check comparisons without
// colliding // colliding
private final RandomAccessVectorValues buildVectors; private final RandomAccessVectorValues<float[]> buildVectors;
/** /**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@ -72,7 +72,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction. * to ensure repeatable construction.
*/ */
public Lucene90HnswGraphBuilder( public Lucene90HnswGraphBuilder(
RandomAccessVectorValues vectors, RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
int maxConn, int maxConn,
int beamWidth, int beamWidth,
@ -103,7 +103,8 @@ public final class Lucene90HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors * accessor for the vectors
*/ */
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
throws IOException {
if (vectors == vectorValues) { if (vectors == vectorValues) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -229,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
float[] candidate, float[] candidate,
float score, float score,
Lucene90NeighborArray neighbors, Lucene90NeighborArray neighbors,
RandomAccessVectorValues vectorValues) RandomAccessVectorValues<float[]> vectorValues)
throws IOException { throws IOException {
bound.set(score); bound.set(score);
for (int i = 0; i < neighbors.size(); i++) { for (int i = 0; i < neighbors.size(); i++) {

View File

@ -27,6 +27,7 @@ import java.util.Map;
import java.util.SplittableRandom; import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -232,6 +233,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
return getOffHeapVectorValues(fieldEntry); return getOffHeapVectorValues(fieldEntry);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {
@ -352,7 +358,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
} }
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues { static class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
final int dimension; final int dimension;
final int[] ordToDoc; final int[] ordToDoc;
@ -433,7 +440,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public RandomAccessVectorValues copy() { public RandomAccessVectorValues<float[]> copy() {
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone()); return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
} }
@ -443,17 +450,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
dataIn.readFloats(value, 0, value.length); dataIn.readFloats(value, 0, value.length);
return value; return value;
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
dataIn.seek((long) targetOrd * byteSize);
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
} }
/** Read the nearest-neighbors graph from the index input */ /** Read the nearest-neighbors graph from the index input */

View File

@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
float[] query, float[] query,
int topK, int topK,
int numSeed, int numSeed,
RandomAccessVectorValues vectors, RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
HnswGraph graphValues, HnswGraph graphValues,
Bits acceptOrds, Bits acceptOrds,

View File

@ -27,6 +27,7 @@ import java.util.Map;
import java.util.function.IntUnaryOperator; import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -224,6 +225,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
return getOffHeapVectorValues(fieldEntry); return getOffHeapVectorValues(fieldEntry);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {
@ -398,7 +404,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
} }
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues { static class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
private final int dimension; private final int dimension;
private final int size; private final int size;
@ -486,7 +493,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public RandomAccessVectorValues copy() { public RandomAccessVectorValues<float[]> copy() {
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone()); return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
} }
@ -496,17 +503,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
dataIn.readFloats(value, 0, value.length); dataIn.readFloats(value, 0, value.length);
return value; return value;
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
dataIn.seek((long) targetOrd * byteSize);
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
} }
/** Read the nearest-neighbors graph from the index input */ /** Read the nearest-neighbors graph from the index input */

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -219,6 +220,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
return OffHeapVectorValues.load(fieldEntry, vectorData); return OffHeapVectorValues.load(fieldEntry, vectorData);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {

View File

@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues { abstract class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
protected final int dimension; protected final int dimension;
protected final int size; protected final int size;
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value; return value;
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
public abstract int ordToDoc(int ord); public abstract int ordToDoc(int ord);
static OffHeapVectorValues load( static OffHeapVectorValues load(
@ -137,7 +127,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone()); return new DenseOffHeapVectorValues(dimension, size, slice.clone());
} }
@ -210,7 +200,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone()); return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
} }
@ -282,7 +272,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -291,11 +281,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public int ordToDoc(int ord) { public int ordToDoc(int ord) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -233,12 +234,31 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); FieldEntry fieldEntry = fields.get(field);
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData); if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) { throw new IllegalArgumentException(
return new ExpandingVectorValues(values); "field=\""
} else { + field
return values; + "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
} }
return OffHeapVectorValues.load(fieldEntry, vectorData);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
} }
@Override @Override
@ -292,7 +312,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
// bound k by total number of vectors to prevent oversizing data structures // bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size()); k = Math.min(k, fieldEntry.size());
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData); OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
NeighborQueue results = NeighborQueue results =
HnswGraphSearcher.search( HnswGraphSearcher.search(

View File

@ -0,0 +1,283 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene94;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final BytesRef binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = byteSize;
byteBuffer = ByteBuffer.allocate(byteSize);
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return size;
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
public abstract int ordToDoc(int ord);
static OffHeapByteVectorValues load(
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
}
IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
int byteSize = fieldEntry.dimension;
if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
} else {
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
}
}
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
super(dimension, size, slice, byteSize);
}
@Override
public BytesRef vectorValue() throws IOException {
slice.seek((long) doc * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
return binaryValue;
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@Override
public int ordToDoc(int ord) {
return ord;
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
}
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
private final DirectMonotonicReader ordToDoc;
private final IndexedDISI disi;
// dataIn was used to init a new IndexedDIS for #randomAccess()
private final IndexInput dataIn;
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
public SparseOffHeapVectorValues(
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
IndexInput dataIn,
IndexInput slice,
int byteSize)
throws IOException {
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
this.fieldEntry = fieldEntry;
final RandomAccessInput addressesData =
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
this.dataIn = dataIn;
this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
this.disi =
new IndexedDISI(
dataIn,
fieldEntry.docsWithFieldOffset,
fieldEntry.docsWithFieldLength,
fieldEntry.jumpTableEntryCount,
fieldEntry.denseRankPower,
fieldEntry.size);
}
@Override
public BytesRef vectorValue() throws IOException {
slice.seek((long) (disi.index()) * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
return binaryValue;
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@Override
public int ordToDoc(int ord) {
return (int) ordToDoc.get(ord);
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) {
return null;
}
return new Bits() {
@Override
public boolean get(int index) {
return acceptDocs.get(ordToDoc(index));
}
@Override
public int length() {
return size;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null, 0);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
}
@Override
public int size() {
return 0;
}
@Override
public BytesRef vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
}
@Override
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
}
}

View File

@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues { abstract class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
protected final int dimension; protected final int dimension;
protected final int size; protected final int size;
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value; return value;
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
public abstract int ordToDoc(int ord); public abstract int ordToDoc(int ord);
static OffHeapVectorValues load( static OffHeapVectorValues load(
@ -143,7 +133,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
} }
@ -219,7 +209,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
} }
@ -291,7 +281,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -300,11 +290,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public int ordToDoc(int ord) { public int ordToDoc(int ord) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private void writeGraph( private void writeGraph(
IndexOutput graphData, IndexOutput graphData,
RandomAccessVectorValues vectorValues, RandomAccessVectorValues<float[]> vectorValues,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
long graphDataOffset, long graphDataOffset,
long[] offsets, long[] offsets,

View File

@ -53,7 +53,7 @@ public final class Lucene91HnswGraphBuilder {
private final Lucene91NeighborArray scratch; private final Lucene91NeighborArray scratch;
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues vectorValues; private final RandomAccessVectorValues<float[]> vectorValues;
private final SplittableRandom random; private final SplittableRandom random;
private final Lucene91BoundsChecker bound; private final Lucene91BoundsChecker bound;
private final HnswGraphSearcher<float[]> graphSearcher; private final HnswGraphSearcher<float[]> graphSearcher;
@ -64,7 +64,7 @@ public final class Lucene91HnswGraphBuilder {
// we need two sources of vectors in order to perform diversity check comparisons without // we need two sources of vectors in order to perform diversity check comparisons without
// colliding // colliding
private RandomAccessVectorValues buildVectors; private RandomAccessVectorValues<float[]> buildVectors;
/** /**
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense * Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
@ -79,7 +79,7 @@ public final class Lucene91HnswGraphBuilder {
* to ensure repeatable construction. * to ensure repeatable construction.
*/ */
public Lucene91HnswGraphBuilder( public Lucene91HnswGraphBuilder(
RandomAccessVectorValues vectors, RandomAccessVectorValues<float[]> vectors,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
int maxConn, int maxConn,
int beamWidth, int beamWidth,
@ -119,7 +119,8 @@ public final class Lucene91HnswGraphBuilder {
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
* accessor for the vectors * accessor for the vectors
*/ */
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException { public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
throws IOException {
if (vectors == vectorValues) { if (vectors == vectorValues) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -250,7 +251,7 @@ public final class Lucene91HnswGraphBuilder {
float[] candidate, float[] candidate,
float score, float score,
Lucene91NeighborArray neighbors, Lucene91NeighborArray neighbors,
RandomAccessVectorValues vectorValues) RandomAccessVectorValues<float[]> vectorValues)
throws IOException { throws IOException {
bound.set(score); bound.set(score);
for (int i = 0; i < neighbors.size(); i++) { for (int i = 0; i < neighbors.size(); i++) {

View File

@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
} }
private Lucene91OnHeapHnswGraph writeGraph( private Lucene91OnHeapHnswGraph writeGraph(
RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction) RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException { throws IOException {
// build graph // build graph

View File

@ -268,13 +268,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
} }
private OnHeapHnswGraph writeGraph( private OnHeapHnswGraph writeGraph(
RandomAccessVectorValues vectorValues, RandomAccessVectorValues<float[]> vectorValues,
VectorEncoding vectorEncoding, VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction) VectorSimilarityFunction similarityFunction)
throws IOException { throws IOException {
// build graph // build graph
HnswGraphBuilder<?> hnswGraphBuilder = HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectorValues, vectorValues,
vectorEncoding, vectorEncoding,

View File

@ -30,12 +30,14 @@ import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
@ -379,8 +381,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
IndexOutput tempVectorData = IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput( segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context); vectorData.getName(), "temp", segmentWriteState.context);
@ -389,7 +389,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
try { try {
// write the vector data to a temporary file // write the vector data to a temporary file
DocsWithFieldSet docsWithField = DocsWithFieldSet docsWithField =
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize); switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
};
CodecUtil.writeFooter(tempVectorData); CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData); IOUtils.close(tempVectorData);
@ -405,23 +410,49 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction // we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds // doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator? // TODO: separate random access vector values from DocIdSetIterator?
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize; int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
OffHeapVectorValues offHeapVectors =
new OffHeapVectorValues.DenseOffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
OnHeapHnswGraph graph = null; OnHeapHnswGraph graph = null;
if (offHeapVectors.size() != 0) { if (docsWithField.cardinality() != 0) {
// build graph // build graph
HnswGraphBuilder<?> hnswGraphBuilder = graph =
HnswGraphBuilder.create( switch (fieldInfo.getVectorEncoding()) {
offHeapVectors, case BYTE -> {
fieldInfo.getVectorEncoding(), OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
fieldInfo.getVectorSimilarityFunction(), new OffHeapByteVectorValues.DenseOffHeapVectorValues(
M, fieldInfo.getVectorDimension(),
beamWidth, docsWithField.cardinality(),
HnswGraphBuilder.randSeed); vectorDataInput,
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); byteSize);
graph = hnswGraphBuilder.build(offHeapVectors.copy()); HnswGraphBuilder<BytesRef> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
case FLOAT32 -> {
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
new OffHeapVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
};
writeGraph(graph); writeGraph(graph);
} }
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@ -554,16 +585,37 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} }
} }
/**
* Writes the byte vector values to the output and returns a set of documents that contains
* vectors.
*/
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
BytesRef binaryValue = byteVectorValues.binaryValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
}
/** /**
* Writes the vector values to the output and returns a set of documents that contains vectors. * Writes the vector values to the output and returns a set of documents that contains vectors.
*/ */
private static DocsWithFieldSet writeVectorData( private static DocsWithFieldSet writeVectorData(
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException { IndexOutput output, VectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet(); DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
// write vector // write vector
BytesRef binaryValue = vectors.binaryValue(); BytesRef binaryValue = floatVectorValues.binaryValue();
assert binaryValue.length == vectors.dimension() * scalarSize; assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV); docsWithField.add(docV);
} }
@ -580,7 +632,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
private final int dim; private final int dim;
private final DocsWithFieldSet docsWithField; private final DocsWithFieldSet docsWithField;
private final List<T> vectors; private final List<T> vectors;
private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder<T> hnswGraphBuilder; private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1; private int lastDocID = -1;
@ -593,8 +644,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) { case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
@Override @Override
public BytesRef copyValue(BytesRef value) { public BytesRef copyValue(BytesRef value) {
return new BytesRef( return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
} }
}; };
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) { case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
@ -613,16 +663,15 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
this.dim = fieldInfo.getVectorDimension(); this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet(); this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>(); vectors = new ArrayList<>();
raVectorValues = new RAVectorValues<>(vectors, dim); RAVectorValues<T> raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder = hnswGraphBuilder =
(HnswGraphBuilder<T>) HnswGraphBuilder.create(
HnswGraphBuilder.create( raVectorValues,
raVectorValues, fieldInfo.getVectorEncoding(),
fieldInfo.getVectorEncoding(), fieldInfo.getVectorSimilarityFunction(),
fieldInfo.getVectorSimilarityFunction(), M,
M, beamWidth,
beamWidth, HnswGraphBuilder.randSeed);
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream); hnswGraphBuilder.setInfoStream(infoStream);
} }
@ -667,7 +716,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} }
} }
private static class RAVectorValues<T> implements RandomAccessVectorValues { private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors; private final List<T> vectors;
private final int dim; private final int dim;
@ -687,17 +736,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
} }
@Override @Override
public float[] vectorValue(int targetOrd) throws IOException { public T vectorValue(int targetOrd) throws IOException {
return (float[]) vectors.get(targetOrd); return vectors.get(targetOrd);
} }
@Override @Override
public BytesRef binaryValue(int targetOrd) throws IOException { public RAVectorValues<T> copy() throws IOException {
return (BytesRef) vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues copy() throws IOException {
return this; return this;
} }
} }

View File

@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.IndexFileNames;
@ -143,6 +144,39 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding()); return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldInfo info = readState.fieldInfos.fieldInfo(field);
if (info == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
return null;
}
int dimension = info.getVectorDimension();
if (dimension == 0) {
throw new IllegalStateException(
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
}
FieldEntry fieldEntry = fieldEntries.get(field);
if (fieldEntry == null) {
// mirror the handling in Lucene90VectorReader#getVectorValues
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
return null;
}
if (dimension != fieldEntry.dimension) {
throw new IllegalStateException(
"Inconsistent vector dimension for field=\""
+ field
+ "\"; "
+ dimension
+ " != "
+ fieldEntry.dimension);
}
IndexInput bytesSlice =
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {
@ -187,7 +221,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override @Override
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {
VectorValues values = getVectorValues(field); ByteVectorValues values = getByteVectorValues(field);
if (target.length != values.dimension()) { if (target.length != values.dimension()) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"vector query dimension: " "vector query dimension: "
@ -213,7 +247,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
break; break;
} }
BytesRef vector = values.binaryValue(); BytesRef vector = values.vectorValue();
float score = vectorSimilarity.compare(vector, target); float score = vectorSimilarity.compare(vector, target);
topK.insertWithOverflow(new ScoreDoc(doc, score)); topK.insertWithOverflow(new ScoreDoc(doc, score));
numVisited++; numVisited++;
@ -301,7 +335,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
} }
private static class SimpleTextVectorValues extends VectorValues private static class SimpleTextVectorValues extends VectorValues
implements RandomAccessVectorValues { implements RandomAccessVectorValues<float[]> {
private final BytesRefBuilder scratch = new BytesRefBuilder(); private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry; private final FieldEntry entry;
@ -356,7 +390,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
} }
@Override @Override
public RandomAccessVectorValues copy() { public RandomAccessVectorValues<float[]> copy() {
return this; return this;
} }
@ -409,10 +443,99 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
public float[] vectorValue(int targetOrd) throws IOException { public float[] vectorValue(int targetOrd) throws IOException {
return values[targetOrd]; return values[targetOrd];
} }
}
private static class SimpleTextByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry;
private final IndexInput in;
private final BytesRef binaryValue;
private final byte[][] values;
int curOrd;
SimpleTextByteVectorValues(FieldEntry entry, IndexInput in) throws IOException {
this.entry = entry;
this.in = in;
values = new byte[entry.size()][entry.dimension];
binaryValue = new BytesRef(entry.dimension);
binaryValue.length = binaryValue.bytes.length;
curOrd = -1;
readAllVectors();
}
@Override @Override
public BytesRef binaryValue(int targetOrd) throws IOException { public int dimension() {
throw new UnsupportedOperationException(); return entry.dimension;
}
@Override
public int size() {
return entry.size();
}
@Override
public BytesRef vectorValue() {
binaryValue.bytes = values[curOrd];
return binaryValue;
}
@Override
public RandomAccessVectorValues<BytesRef> copy() {
return this;
}
@Override
public int docID() {
if (curOrd == -1) {
return -1;
} else if (curOrd >= entry.size()) {
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
// immediately afterward should also return NO_MORE_DOCS
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
return NO_MORE_DOCS;
}
return entry.ordToDoc[curOrd];
}
@Override
public int nextDoc() throws IOException {
if (++curOrd < entry.size()) {
return docID();
}
return NO_MORE_DOCS;
}
@Override
public int advance(int target) throws IOException {
return slowAdvance(target);
}
private void readAllVectors() throws IOException {
for (byte[] value : values) {
readVector(value);
}
}
private void readVector(byte[] value) throws IOException {
SimpleTextUtil.readLine(in, scratch);
// skip leading "[" and strip trailing "]"
String s = new BytesRef(scratch.bytes(), 1, scratch.length() - 2).utf8ToString();
String[] floatStrings = s.split(",");
assert floatStrings.length == value.length
: " read " + s + " when expecting " + value.length + " floats";
for (int i = 0; i < floatStrings.length; i++) {
value[i] = (byte) Float.parseFloat(floatStrings[i]);
}
}
@Override
public BytesRef vectorValue(int targetOrd) throws IOException {
binaryValue.bytes = values[curOrd];
return binaryValue;
} }
} }

View File

@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
import java.nio.ByteOrder; import java.nio.ByteOrder;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
@ -85,6 +86,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
: vectorValues; : vectorValues;
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs search( public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
@ -202,6 +208,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState); return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) {
throw new UnsupportedOperationException();
}
@Override @Override
public void checkIntegrity() {} public void checkIntegrity() {}
}; };
@ -228,7 +239,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
} }
@Override @Override
public void addValue(int docID, Object value) { public void addValue(int docID, float[] value) {
if (docID == lastDocID) { if (docID == lastDocID) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"VectorValuesField \"" "VectorValuesField \""
@ -236,25 +247,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
+ "\" appears more than once in this document (only one value is allowed per field)"); + "\" appears more than once in this document (only one value is allowed per field)");
} }
assert docID > lastDocID; assert docID > lastDocID;
float[] vectorValue =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> (float[]) value;
case BYTE -> bytesToFloats((BytesRef) value);
};
docsWithField.add(docID); docsWithField.add(docID);
vectors.add(copyValue(vectorValue)); vectors.add(copyValue(value));
lastDocID = docID; lastDocID = docID;
} }
private float[] bytesToFloats(BytesRef b) {
// This is used only by SimpleTextKnnVectorsWriter
float[] floats = new float[dim];
for (int i = 0; i < dim; i++) {
floats[i] = b.bytes[i + b.offset];
}
return floats;
}
@Override @Override
public float[] copyValue(float[] vectorValue) { public float[] copyValue(float[] vectorValue) {
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim); return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);

View File

@ -34,7 +34,7 @@ public abstract class KnnFieldVectorsWriter<T> implements Accountable {
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in * Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
* increasing order. * increasing order.
*/ */
public abstract void addValue(int docID, Object vectorValue) throws IOException; public abstract void addValue(int docID, T vectorValue) throws IOException;
/** /**
* Used to copy values being indexed to internal storage. * Used to copy values being indexed to internal storage.

View File

@ -18,6 +18,7 @@
package org.apache.lucene.codecs; package org.apache.lucene.codecs;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
@ -98,6 +99,11 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public ByteVectorValues getByteVectorValues(String field) {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs search( public TopDocs search(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {

View File

@ -19,6 +19,7 @@ package org.apache.lucene.codecs;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
@ -51,6 +52,13 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
*/ */
public abstract VectorValues getVectorValues(String field) throws IOException; public abstract VectorValues getVectorValues(String field) throws IOException;
/**
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
* never {@code null}.
*/
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
/** /**
* Return the k nearest neighbor documents as determined by comparison of their vector values for * Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document * this field, to the given vector, by the field's similarity function. The score of each document

View File

@ -21,10 +21,12 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.Sorter; import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Accountable;
@ -44,13 +46,29 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
/** Write field for merging */ /** Write field for merging */
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo); switch (fieldInfo.getVectorEncoding()) {
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState); case BYTE:
for (int doc = mergedValues.nextDoc(); KnnFieldVectorsWriter<BytesRef> byteWriter =
doc != DocIdSetIterator.NO_MORE_DOCS; (KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
doc = mergedValues.nextDoc()) { ByteVectorValues mergedBytes =
writer.addValue(doc, mergedValues.vectorValue()); MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
for (int doc = mergedBytes.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = mergedBytes.nextDoc()) {
byteWriter.addValue(doc, mergedBytes.vectorValue());
}
break;
case FLOAT32:
KnnFieldVectorsWriter<float[]> floatWriter =
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
VectorValues mergedFloats = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
for (int doc = mergedFloats.nextDoc();
doc != DocIdSetIterator.NO_MORE_DOCS;
doc = mergedFloats.nextDoc()) {
floatWriter.addValue(doc, mergedFloats.vectorValue());
}
break;
} }
} }
@ -104,20 +122,34 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
} }
} }
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */ private static class ByteVectorValuesSub extends DocIDMerger.Sub {
protected static class MergedVectorValues extends VectorValues {
private final List<VectorValuesSub> subs;
private final DocIDMerger<VectorValuesSub> docIdMerger;
private final int size;
private int docId; final ByteVectorValues values;
private VectorValuesSub current;
ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
}
}
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
protected static final class MergedVectorValues {
private MergedVectorValues() {}
/** Returns a merged view over all the segment's {@link VectorValues}. */ /** Returns a merged view over all the segment's {@link VectorValues}. */
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState) public static VectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException { throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues(); assert fieldInfo != null && fieldInfo.hasVectorValues();
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
throw new UnsupportedOperationException(
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
}
List<VectorValuesSub> subs = new ArrayList<>(); List<VectorValuesSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) { for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i]; KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
@ -128,60 +160,147 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
} }
} }
} }
return new MergedVectorValues(subs, mergeState); return new MergedFloat32VectorValues(subs, mergeState);
} }
private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState) /** Returns a merged view over all the segment's {@link ByteVectorValues}. */
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
throws IOException { throws IOException {
this.subs = subs; assert fieldInfo != null && fieldInfo.hasVectorValues();
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
int totalSize = 0; throw new UnsupportedOperationException(
for (VectorValuesSub sub : subs) { "Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
totalSize += sub.values.size();
} }
size = totalSize; List<ByteVectorValuesSub> subs = new ArrayList<>();
docId = -1; for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
} KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
if (knnVectorsReader != null) {
@Override ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
public int docID() { if (values != null) {
return docId; subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
} }
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
} }
return docId; return new MergedByteVectorValues(subs, mergeState);
} }
@Override static class MergedFloat32VectorValues extends VectorValues {
public float[] vectorValue() throws IOException { private final List<VectorValuesSub> subs;
return current.values.vectorValue(); private final DocIDMerger<VectorValuesSub> docIdMerger;
private final int size;
private int docId;
VectorValuesSub current;
private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (VectorValuesSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
}
return docId;
}
@Override
public float[] vectorValue() throws IOException {
return current.values.vectorValue();
}
@Override
public BytesRef binaryValue() throws IOException {
return current.values.binaryValue();
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return subs.get(0).values.dimension();
}
} }
@Override static class MergedByteVectorValues extends ByteVectorValues {
public BytesRef binaryValue() throws IOException { private final List<ByteVectorValuesSub> subs;
return current.values.binaryValue(); private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
} private final int size;
@Override private int docId;
public int advance(int target) { ByteVectorValuesSub current;
throw new UnsupportedOperationException();
}
@Override private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
public int size() { throws IOException {
return size; this.subs = subs;
} docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (ByteVectorValuesSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override @Override
public int dimension() { public BytesRef vectorValue() throws IOException {
return subs.get(0).values.dimension(); return current.values.vectorValue();
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
}
return docId;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return subs.get(0).values.dimension();
}
} }
} }
} }

View File

@ -1,47 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import org.apache.lucene.index.FilterVectorValues;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef;
/** reads from byte-encoded data */
class ExpandingVectorValues extends FilterVectorValues {
private final float[] value;
/**
* @param in the wrapped values
*/
protected ExpandingVectorValues(VectorValues in) {
super(in);
value = new float[in.dimension()];
}
@Override
public float[] vectorValue() throws IOException {
BytesRef binaryValue = binaryValue();
byte[] bytes = binaryValue.bytes;
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
value[i] = bytes[j];
}
return value;
}
}

View File

@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -238,12 +239,31 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field); FieldEntry fieldEntry = fields.get(field);
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData); if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) { throw new IllegalArgumentException(
return new ExpandingVectorValues(values); "field=\""
} else { + field
return values; + "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
} }
return OffHeapVectorValues.load(fieldEntry, vectorData);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
} }
@Override @Override
@ -303,7 +323,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
// bound k by total number of vectors to prevent oversizing data structures // bound k by total number of vectors to prevent oversizing data structures
k = Math.min(k, fieldEntry.size()); k = Math.min(k, fieldEntry.size());
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData); OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
NeighborQueue results = NeighborQueue results =
HnswGraphSearcher.search( HnswGraphSearcher.search(

View File

@ -391,17 +391,21 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
@Override @Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
IndexOutput tempVectorData = IndexOutput tempVectorData =
segmentWriteState.directory.createTempOutput( segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context); vectorData.getName(), "temp", segmentWriteState.context);
IndexInput vectorDataInput = null; IndexInput vectorDataInput = null;
boolean success = false; boolean success = false;
try { try {
// write the vector data to a temporary file
// write the vector data to a temporary file // write the vector data to a temporary file
DocsWithFieldSet docsWithField = DocsWithFieldSet docsWithField =
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize); switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
};
CodecUtil.writeFooter(tempVectorData); CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData); IOUtils.close(tempVectorData);
@ -417,24 +421,50 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
// we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction // we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// doesn't need to know docIds // doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator? // TODO: separate random access vector values from DocIdSetIterator?
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize; int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
OffHeapVectorValues offHeapVectors =
new OffHeapVectorValues.DenseOffHeapVectorValues(
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
OnHeapHnswGraph graph = null; OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null; int[][] vectorIndexNodeOffsets = null;
if (offHeapVectors.size() != 0) { if (docsWithField.cardinality() != 0) {
// build graph // build graph
HnswGraphBuilder<?> hnswGraphBuilder = graph =
HnswGraphBuilder.create( switch (fieldInfo.getVectorEncoding()) {
offHeapVectors, case BYTE -> {
fieldInfo.getVectorEncoding(), OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
fieldInfo.getVectorSimilarityFunction(), new OffHeapByteVectorValues.DenseOffHeapVectorValues(
M, fieldInfo.getVectorDimension(),
beamWidth, docsWithField.cardinality(),
HnswGraphBuilder.randSeed); vectorDataInput,
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream); byteSize);
graph = hnswGraphBuilder.build(offHeapVectors.copy()); HnswGraphBuilder<BytesRef> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
case FLOAT32 -> {
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
new OffHeapVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize);
HnswGraphBuilder<float[]> hnswGraphBuilder =
HnswGraphBuilder.create(
vectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
yield hnswGraphBuilder.build(vectorValues.copy());
}
};
vectorIndexNodeOffsets = writeGraph(graph); vectorIndexNodeOffsets = writeGraph(graph);
} }
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@ -605,16 +635,37 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
} }
} }
/**
* Writes the byte vector values to the output and returns a set of documents that contains
* vectors.
*/
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
BytesRef binaryValue = byteVectorValues.binaryValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
}
/** /**
* Writes the vector values to the output and returns a set of documents that contains vectors. * Writes the vector values to the output and returns a set of documents that contains vectors.
*/ */
private static DocsWithFieldSet writeVectorData( private static DocsWithFieldSet writeVectorData(
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException { IndexOutput output, VectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet(); DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
// write vector // write vector
BytesRef binaryValue = vectors.binaryValue(); BytesRef binaryValue = floatVectorValues.binaryValue();
assert binaryValue.length == vectors.dimension() * scalarSize; assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length); output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
docsWithField.add(docV); docsWithField.add(docV);
} }
@ -631,7 +682,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
private final int dim; private final int dim;
private final DocsWithFieldSet docsWithField; private final DocsWithFieldSet docsWithField;
private final List<T> vectors; private final List<T> vectors;
private final RAVectorValues<T> raVectorValues;
private final HnswGraphBuilder<T> hnswGraphBuilder; private final HnswGraphBuilder<T> hnswGraphBuilder;
private int lastDocID = -1; private int lastDocID = -1;
@ -657,36 +707,31 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
}; };
} }
@SuppressWarnings("unchecked")
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
throws IOException { throws IOException {
this.fieldInfo = fieldInfo; this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension(); this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet(); this.docsWithField = new DocsWithFieldSet();
vectors = new ArrayList<>(); vectors = new ArrayList<>();
raVectorValues = new RAVectorValues<>(vectors, dim);
hnswGraphBuilder = hnswGraphBuilder =
(HnswGraphBuilder<T>) HnswGraphBuilder.create(
HnswGraphBuilder.create( new RAVectorValues<>(vectors, dim),
raVectorValues, fieldInfo.getVectorEncoding(),
fieldInfo.getVectorEncoding(), fieldInfo.getVectorSimilarityFunction(),
fieldInfo.getVectorSimilarityFunction(), M,
M, beamWidth,
beamWidth, HnswGraphBuilder.randSeed);
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream); hnswGraphBuilder.setInfoStream(infoStream);
} }
@Override @Override
@SuppressWarnings("unchecked") public void addValue(int docID, T vectorValue) throws IOException {
public void addValue(int docID, Object value) throws IOException {
if (docID == lastDocID) { if (docID == lastDocID) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"VectorValuesField \"" "VectorValuesField \""
+ fieldInfo.name + fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)"); + "\" appears more than once in this document (only one value is allowed per field)");
} }
T vectorValue = (T) value;
assert docID > lastDocID; assert docID > lastDocID;
docsWithField.add(docID); docsWithField.add(docID);
vectors.add(copyValue(vectorValue)); vectors.add(copyValue(vectorValue));
@ -719,7 +764,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
} }
} }
private static class RAVectorValues<T> implements RandomAccessVectorValues { private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors; private final List<T> vectors;
private final int dim; private final int dim;
@ -739,17 +784,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
} }
@Override @Override
public float[] vectorValue(int targetOrd) throws IOException { public T vectorValue(int targetOrd) throws IOException {
return (float[]) vectors.get(targetOrd); return vectors.get(targetOrd);
} }
@Override @Override
public BytesRef binaryValue(int targetOrd) throws IOException { public RandomAccessVectorValues<T> copy() throws IOException {
return (BytesRef) vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues copy() throws IOException {
return this; return this;
} }
} }

View File

@ -0,0 +1,283 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene95;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapByteVectorValues extends ByteVectorValues
    implements RandomAccessVectorValues<BytesRef> {

  protected final int dimension;
  protected final int size;
  // Slice of the data file holding the raw vector bytes, ord-contiguous.
  protected final IndexInput slice;
  // Shared return value; backed by byteBuffer's array, so every read overwrites it.
  protected final BytesRef binaryValue;
  protected final ByteBuffer byteBuffer;
  // Bytes per vector (== dimension for BYTE encoding).
  protected final int byteSize;

  OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
    this.dimension = dimension;
    this.size = size;
    this.slice = slice;
    this.byteSize = byteSize;
    byteBuffer = ByteBuffer.allocate(byteSize);
    binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
  }

  @Override
  public int dimension() {
    return dimension;
  }

  @Override
  public int size() {
    return size;
  }

  // Random access by dense ordinal; returns the shared (mutable) binaryValue.
  @Override
  public BytesRef vectorValue(int targetOrd) throws IOException {
    readValue(targetOrd);
    return binaryValue;
  }

  // Seek to the ordinal's fixed-width record and fill the shared buffer.
  private void readValue(int targetOrd) throws IOException {
    slice.seek((long) targetOrd * byteSize);
    slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
  }

  /** Maps a dense ordinal back to a docID. */
  public abstract int ordToDoc(int ord);

  /**
   * Loads the byte vectors for a field. Returns an empty instance when the field has no vectors
   * (docsWithFieldOffset == -2) or is not BYTE-encoded; otherwise dense (-1 marker: every doc has
   * a vector) or sparse (IndexedDISI-backed) values.
   */
  static OffHeapByteVectorValues load(
      Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
    if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
      return new EmptyOffHeapVectorValues(fieldEntry.dimension);
    }
    IndexInput bytesSlice =
        vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
    // One byte per dimension for BYTE-encoded vectors.
    int byteSize = fieldEntry.dimension;
    if (fieldEntry.docsWithFieldOffset == -1) {
      return new DenseOffHeapVectorValues(
          fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
    } else {
      return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
    }
  }

  /** Translates doc-level accept bits into ord-level bits (null means accept all). */
  abstract Bits getAcceptOrds(Bits acceptDocs);

  /** Every doc in [0, size) has a vector; ord == docID. */
  static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {

    private int doc = -1;

    public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
      super(dimension, size, slice, byteSize);
    }

    @Override
    public BytesRef vectorValue() throws IOException {
      // Dense layout: record offset is doc * byteSize.
      slice.seek((long) doc * byteSize);
      slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
      return binaryValue;
    }

    @Override
    public int docID() {
      return doc;
    }

    @Override
    public int nextDoc() throws IOException {
      return advance(doc + 1);
    }

    @Override
    public int advance(int target) throws IOException {
      assert docID() < target;
      if (target >= size) {
        return doc = NO_MORE_DOCS;
      }
      return doc = target;
    }

    // Clones the slice so the copy can be positioned independently (used for diversity checks).
    @Override
    public RandomAccessVectorValues<BytesRef> copy() throws IOException {
      return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
    }

    @Override
    public int ordToDoc(int ord) {
      // Dense: identity mapping.
      return ord;
    }

    @Override
    Bits getAcceptOrds(Bits acceptDocs) {
      // Dense: ords and docIDs coincide, so the doc bits work directly.
      return acceptDocs;
    }
  }

  /** Only some docs have a vector; docID iteration via IndexedDISI, ord→doc via monotonic reader. */
  private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
    private final DirectMonotonicReader ordToDoc;
    private final IndexedDISI disi;
    // dataIn was used to init a new IndexedDIS for #randomAccess()
    private final IndexInput dataIn;
    private final Lucene95HnswVectorsReader.FieldEntry fieldEntry;

    public SparseOffHeapVectorValues(
        Lucene95HnswVectorsReader.FieldEntry fieldEntry,
        IndexInput dataIn,
        IndexInput slice,
        int byteSize)
        throws IOException {
      super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
      this.fieldEntry = fieldEntry;
      final RandomAccessInput addressesData =
          dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
      this.dataIn = dataIn;
      this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
      this.disi =
          new IndexedDISI(
              dataIn,
              fieldEntry.docsWithFieldOffset,
              fieldEntry.docsWithFieldLength,
              fieldEntry.jumpTableEntryCount,
              fieldEntry.denseRankPower,
              fieldEntry.size);
    }

    @Override
    public BytesRef vectorValue() throws IOException {
      // disi.index() is the current doc's dense ordinal within the vector data.
      slice.seek((long) (disi.index()) * byteSize);
      slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
      return binaryValue;
    }

    @Override
    public int docID() {
      return disi.docID();
    }

    @Override
    public int nextDoc() throws IOException {
      return disi.nextDoc();
    }

    @Override
    public int advance(int target) throws IOException {
      assert docID() < target;
      return disi.advance(target);
    }

    @Override
    public RandomAccessVectorValues<BytesRef> copy() throws IOException {
      return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
    }

    @Override
    public int ordToDoc(int ord) {
      return (int) ordToDoc.get(ord);
    }

    @Override
    Bits getAcceptOrds(Bits acceptDocs) {
      if (acceptDocs == null) {
        return null;
      }
      // View over acceptDocs addressed by ord: remap through ordToDoc on every get().
      return new Bits() {
        @Override
        public boolean get(int index) {
          return acceptDocs.get(ordToDoc(index));
        }

        @Override
        public int length() {
          return size;
        }
      };
    }
  }

  /** Placeholder for fields with no (byte) vectors; every accessor is empty or unsupported. */
  private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {

    public EmptyOffHeapVectorValues(int dimension) {
      // No backing data: null slice, zero-length records.
      super(dimension, 0, null, 0);
    }

    private int doc = -1;

    @Override
    public int dimension() {
      return super.dimension();
    }

    @Override
    public int size() {
      return 0;
    }

    @Override
    public BytesRef vectorValue() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public int docID() {
      return doc;
    }

    @Override
    public int nextDoc() throws IOException {
      return advance(doc + 1);
    }

    @Override
    public int advance(int target) throws IOException {
      // No docs to visit: always exhausted.
      return doc = NO_MORE_DOCS;
    }

    @Override
    public RandomAccessVectorValues<BytesRef> copy() throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public BytesRef vectorValue(int targetOrd) throws IOException {
      throw new UnsupportedOperationException();
    }

    @Override
    public int ordToDoc(int ord) {
      throw new UnsupportedOperationException();
    }

    @Override
    Bits getAcceptOrds(Bits acceptDocs) {
      return null;
    }
  }
}

View File

@ -20,6 +20,7 @@ package org.apache.lucene.codecs.lucene95;
import java.io.IOException; import java.io.IOException;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.store.RandomAccessInput;
@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */ /** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues { abstract class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
protected final int dimension; protected final int dimension;
protected final int size; protected final int size;
@ -66,31 +68,17 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
return value; return value;
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
readValue(targetOrd);
return binaryValue;
}
private void readValue(int targetOrd) throws IOException {
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
}
public abstract int ordToDoc(int ord); public abstract int ordToDoc(int ord);
static OffHeapVectorValues load( static OffHeapVectorValues load(
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException { Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
if (fieldEntry.docsWithFieldOffset == -2) { if (fieldEntry.docsWithFieldOffset == -2
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return new EmptyOffHeapVectorValues(fieldEntry.dimension); return new EmptyOffHeapVectorValues(fieldEntry.dimension);
} }
IndexInput bytesSlice = IndexInput bytesSlice =
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength); vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
int byteSize = int byteSize = fieldEntry.dimension * Float.BYTES;
switch (fieldEntry.vectorEncoding) {
case BYTE -> fieldEntry.dimension;
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
};
if (fieldEntry.docsWithFieldOffset == -1) { if (fieldEntry.docsWithFieldOffset == -1) {
return new DenseOffHeapVectorValues( return new DenseOffHeapVectorValues(
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize); fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
@ -143,7 +131,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize); return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
} }
@ -219,7 +207,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize); return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
} }
@ -291,7 +279,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
} }
@Override @Override
public RandomAccessVectorValues copy() throws IOException { public RandomAccessVectorValues<float[]> copy() throws IOException {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@ -300,11 +288,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public BytesRef binaryValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public int ordToDoc(int ord) { public int ordToDoc(int ord) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -27,6 +27,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentReadState;
@ -255,6 +256,16 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
} }
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
KnnVectorsReader knnVectorsReader = fields.get(field);
if (knnVectorsReader == null) {
return null;
} else {
return knnVectorsReader.getByteVectorValues(field);
}
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {

View File

@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.document;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.BytesRef;
/**
 * A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
 * - that is, every dimension of a vector contains an explicit value, stored packed into an array
 * (of type byte[]) whose length is the vector dimension. Values can be retrieved using {@link
 * ByteVectorValues}, which is a forward-only docID-based iterator and also offers random-access by
 * dense ordinal (not docId). {@link VectorSimilarityFunction} may be used to compare vectors at
 * query time (for example as part of result ranking). A KnnByteVectorField may be associated with a
 * search similarity function defining the metric used for nearest-neighbor search among vectors of
 * that field.
 *
 * @lucene.experimental
 */
public class KnnByteVectorField extends Field {

  // Builds and freezes a BYTE-encoded vector FieldType; shared by all factory paths.
  private static FieldType buildType(int dimension, VectorSimilarityFunction similarityFunction) {
    FieldType ft = new FieldType();
    ft.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
    ft.freeze();
    return ft;
  }

  // Validates the vector and similarity before deriving the field type from them.
  private static FieldType createType(BytesRef vector, VectorSimilarityFunction similarityFunction) {
    if (vector == null) {
      throw new IllegalArgumentException("vector value must not be null");
    }
    int dimension = vector.length;
    if (dimension == 0) {
      throw new IllegalArgumentException("cannot index an empty vector");
    }
    if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
      throw new IllegalArgumentException(
          "cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
    }
    if (similarityFunction == null) {
      throw new IllegalArgumentException("similarity function must not be null");
    }
    return buildType(dimension, similarityFunction);
  }

  /**
   * Create a new vector query for the provided field targeting the byte vector
   *
   * @param field The field to query
   * @param queryVector The byte vector target
   * @param k The number of nearest neighbors to gather
   * @return A new vector query
   */
  public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
    return new KnnByteVectorQuery(field, queryVector, k);
  }

  /**
   * A convenience method for creating a vector field type.
   *
   * @param dimension dimension of vectors
   * @param similarityFunction a function defining vector proximity.
   * @throws IllegalArgumentException if any parameter is null, or the dimension exceeds {@link
   *     ByteVectorValues#MAX_DIMENSIONS}.
   */
  public static FieldType createFieldType(
      int dimension, VectorSimilarityFunction similarityFunction) {
    return buildType(dimension, similarityFunction);
  }

  /**
   * Creates a numeric vector field. Fields are single-valued: each document has either one value or
   * no value. Vectors of a single field share the same dimension and similarity function. Note that
   * some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
   * be constant-length.
   *
   * @param name field name
   * @param vector value
   * @param similarityFunction a function defining vector proximity.
   * @throws IllegalArgumentException if any parameter is null, or the vector is empty or its
   *     dimension exceeds {@link ByteVectorValues#MAX_DIMENSIONS}.
   */
  public KnnByteVectorField(
      String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
    super(name, createType(vector, similarityFunction));
    fieldsData = vector;
  }

  /**
   * Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
   * single-valued: each document has either one value or no value. Vectors of a single field share
   * the same dimension and similarity function.
   *
   * @param name field name
   * @param vector value
   * @throws IllegalArgumentException if any parameter is null, or the vector is empty or its
   *     dimension exceeds {@link ByteVectorValues#MAX_DIMENSIONS}.
   */
  public KnnByteVectorField(String name, BytesRef vector) {
    this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
  }

  /**
   * Creates a numeric vector field. Fields are single-valued: each document has either one value or
   * no value. Vectors of a single field share the same dimension and similarity function.
   *
   * @param name field name
   * @param vector value
   * @param fieldType field type
   * @throws IllegalArgumentException if any parameter is null, or the field type is not
   *     BYTE-encoded.
   */
  public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
    super(name, fieldType);
    if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
      // Guard against pairing a byte[] payload with a FLOAT32-typed field.
      throw new IllegalArgumentException(
          "Attempt to create a vector for field "
              + name
              + " using byte[] but the field encoding is "
              + fieldType.vectorEncoding());
    }
    fieldsData = vector;
  }

  /** Return the vector value of this field */
  public BytesRef vectorValue() {
    return (BytesRef) fieldsData;
  }

  /**
   * Set the vector value of this field
   *
   * @param value the value to set; must not be null, and length must match the field type
   */
  public void setVectorValue(BytesRef value) {
    if (value == null) {
      throw new IllegalArgumentException("value must not be null");
    }
    if (value.length != type.vectorDimension()) {
      throw new IllegalArgumentException(
          "value length " + value.length + " must match field dimension " + type.vectorDimension());
    }
    fieldsData = value;
  }
}

View File

@ -20,7 +20,8 @@ package org.apache.lucene.document;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
/** /**
@ -41,18 +42,7 @@ public class KnnVectorField extends Field {
if (v == null) { if (v == null) {
throw new IllegalArgumentException("vector value must not be null"); throw new IllegalArgumentException("vector value must not be null");
} }
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction); int dimension = v.length;
}
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
if (v == null) {
throw new IllegalArgumentException("vector value must not be null");
}
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
}
private static FieldType createType(
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
if (dimension == 0) { if (dimension == 0) {
throw new IllegalArgumentException("cannot index an empty vector"); throw new IllegalArgumentException("cannot index an empty vector");
} }
@ -64,13 +54,13 @@ public class KnnVectorField extends Field {
throw new IllegalArgumentException("similarity function must not be null"); throw new IllegalArgumentException("similarity function must not be null");
} }
FieldType type = new FieldType(); FieldType type = new FieldType();
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction); type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
type.freeze(); type.freeze();
return type; return type;
} }
/** /**
* A convenience method for creating a vector field type with the default FLOAT32 encoding. * A convenience method for creating a vector field type.
* *
* @param dimension dimension of vectors * @param dimension dimension of vectors
* @param similarityFunction a function defining vector proximity. * @param similarityFunction a function defining vector proximity.
@ -78,23 +68,22 @@ public class KnnVectorField extends Field {
*/ */
public static FieldType createFieldType( public static FieldType createFieldType(
int dimension, VectorSimilarityFunction similarityFunction) { int dimension, VectorSimilarityFunction similarityFunction) {
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction); FieldType type = new FieldType();
type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
type.freeze();
return type;
} }
/** /**
* A convenience method for creating a vector field type. * Create a new vector query for the provided field targeting the float vector
* *
* @param dimension dimension of vectors * @param field The field to query
* @param vectorEncoding the encoding of the scalar values * @param queryVector The float vector target
* @param similarityFunction a function defining vector proximity. * @param k The number of nearest neighbors to gather
* @throws IllegalArgumentException if any parameter is null, or has dimension &gt; 1024. * @return A new vector query
*/ */
public static FieldType createFieldType( public static Query newVectorQuery(String field, float[] queryVector, int k) {
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) { return new KnnVectorQuery(field, queryVector, k);
FieldType type = new FieldType();
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
type.freeze();
return type;
} }
/** /**
@ -114,23 +103,6 @@ public class KnnVectorField extends Field {
fieldsData = vector; fieldsData = vector;
} }
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
* be constant-length.
*
* @param name field name
* @param vector value
* @param similarityFunction a function defining vector proximity.
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
super(name, createType(vector, similarityFunction));
fieldsData = vector;
}
/** /**
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are * Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
* single-valued: each document has either one value or no value. Vectors of a single field share * single-valued: each document has either one value or no value. Vectors of a single field share
@ -167,28 +139,6 @@ public class KnnVectorField extends Field {
fieldsData = vector; fieldsData = vector;
} }
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function.
*
* @param name field name
* @param vector value
* @param fieldType field type
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
* dimension &gt; 1024.
*/
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
super(name, fieldType);
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"Attempt to create a vector for field "
+ name
+ " using BytesRef but the field encoding is "
+ fieldType.vectorEncoding());
}
fieldsData = vector;
}
/** Return the vector value of this field */ /** Return the vector value of this field */
public float[] vectorValue() { public float[] vectorValue() {
return (float[]) fieldsData; return (float[]) fieldsData;

View File

@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BytesRef;
/**
 * This class provides access to per-document byte vector values indexed as {@link
 * KnnByteVectorField}.
 *
 * @lucene.experimental
 */
public abstract class ByteVectorValues extends DocIdSetIterator {
/** The maximum number of dimensions (bytes) a vector may have */
public static final int MAX_DIMENSIONS = 1024;
/** Sole constructor */
protected ByteVectorValues() {}
/** Return the dimension of the vectors */
public abstract int dimension();
/**
 * Return the number of vectors for this field.
 *
 * @return the number of vectors returned by this iterator
 */
public abstract int size();
/** The cost of iterating equals the number of vectors: one value per matching document. */
@Override
public final long cost() {
return size();
}
/**
 * Return the vector value for the current document ID. It is illegal to call this method when the
 * iterator is not positioned: before advancing, or after failing to advance. The returned
 * {@link BytesRef} may be shared across calls, re-used, and modified as the iterator advances.
 *
 * @return the vector value
 */
public abstract BytesRef vectorValue() throws IOException;
/**
 * Return the binary encoded vector value for the current document ID. For byte vectors this is
 * simply the value returned by {@link #vectorValue}. It is illegal to call this method when the
 * iterator is not positioned: before advancing, or after failing to advance. The returned
 * storage may be shared across calls, re-used and modified as the iterator advances.
 *
 * @return the binary value
 */
public final BytesRef binaryValue() throws IOException {
return vectorValue();
}
}

View File

@ -34,6 +34,7 @@ import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionException;
@ -2588,62 +2589,37 @@ public final class CheckIndex implements Closeable {
+ "\" has vector values but dimension is " + "\" has vector values but dimension is "
+ dimension); + dimension);
} }
VectorValues values = reader.getVectorValues(fieldInfo.name); if (reader.getVectorValues(fieldInfo.name) == null
if (values == null) { && reader.getByteVectorValues(fieldInfo.name) == null) {
continue; continue;
} }
status.totalKnnVectorFields++; status.totalKnnVectorFields++;
switch (fieldInfo.getVectorEncoding()) {
int docCount = 0; case BYTE:
int everyNdoc = Math.max(values.size() / 64, 1); checkByteVectorValues(
while (values.nextDoc() != NO_MORE_DOCS) { Objects.requireNonNull(reader.getByteVectorValues(fieldInfo.name)),
// search the first maxNumSearches vectors to exercise the graph fieldInfo,
if (values.docID() % everyNdoc == 0) { status,
TopDocs docs = reader);
switch (fieldInfo.getVectorEncoding()) { break;
case FLOAT32 -> reader case FLOAT32:
.getVectorReader() checkFloatVectorValues(
.search( Objects.requireNonNull(reader.getVectorValues(fieldInfo.name)),
fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE); fieldInfo,
case BYTE -> reader status,
.getVectorReader() reader);
.search( break;
fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE); default:
};
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
float[] vectorValue = values.vectorValue();
int valueLength = vectorValue.length;
if (valueLength != dimension) {
throw new CheckIndexException( throw new CheckIndexException(
"Field \"" "Field \""
+ fieldInfo.name + fieldInfo.name
+ "\" has a value whose dimension=" + "\" has unexpected vector encoding: "
+ valueLength + fieldInfo.getVectorEncoding());
+ " not matching the field's dimension="
+ dimension);
}
++docCount;
} }
if (docCount != values.size()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has size="
+ values.size()
+ " but when iterated, returns "
+ docCount
+ " docs with values");
}
status.totalVectorValues += docCount;
} }
} }
} }
msg( msg(
infoStream, infoStream,
String.format( String.format(
@ -2667,6 +2643,96 @@ public final class CheckIndex implements Closeable {
return status; return status;
} }
private static void checkFloatVectorValues(
VectorValues values,
FieldInfo fieldInfo,
CheckIndex.Status.VectorValuesStatus status,
CodecReader codecReader)
throws IOException {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
while (values.nextDoc() != NO_MORE_DOCS) {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
TopDocs docs =
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
int valueLength = values.vectorValue().length;
if (valueLength != fieldInfo.getVectorDimension()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has a value whose dimension="
+ valueLength
+ " not matching the field's dimension="
+ fieldInfo.getVectorDimension());
}
++docCount;
}
if (docCount != values.size()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has size="
+ values.size()
+ " but when iterated, returns "
+ docCount
+ " docs with values");
}
status.totalVectorValues += docCount;
}
private static void checkByteVectorValues(
ByteVectorValues values,
FieldInfo fieldInfo,
CheckIndex.Status.VectorValuesStatus status,
CodecReader codecReader)
throws IOException {
int docCount = 0;
int everyNdoc = Math.max(values.size() / 64, 1);
while (values.nextDoc() != NO_MORE_DOCS) {
// search the first maxNumSearches vectors to exercise the graph
if (values.docID() % everyNdoc == 0) {
TopDocs docs =
codecReader
.getVectorReader()
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
if (docs.scoreDocs.length == 0) {
throw new CheckIndexException(
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
}
}
int valueLength = values.vectorValue().length;
if (valueLength != fieldInfo.getVectorDimension()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has a value whose dimension="
+ valueLength
+ " not matching the field's dimension="
+ fieldInfo.getVectorDimension());
}
++docCount;
}
if (docCount != values.size()) {
throw new CheckIndexException(
"Field \""
+ fieldInfo.name
+ "\" has size="
+ values.size()
+ " but when iterated, returns "
+ docCount
+ " docs with values");
}
status.totalVectorValues += docCount;
}
/** /**
* Walks the entire N-dimensional points space, verifying that all points fall within the last * Walks the entire N-dimensional points space, verifying that all points fall within the last
* cell's boundaries. * cell's boundaries.

View File

@ -218,7 +218,9 @@ public abstract class CodecReader extends LeafReader {
public final VectorValues getVectorValues(String field) throws IOException { public final VectorValues getVectorValues(String field) throws IOException {
ensureOpen(); ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field); FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorDimension() == 0) { if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
// Field does not exist or does not index vectors // Field does not exist or does not index vectors
return null; return null;
} }
@ -226,6 +228,20 @@ public abstract class CodecReader extends LeafReader {
return getVectorReader().getVectorValues(field); return getVectorReader().getVectorValues(field);
} }
@Override
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
ensureOpen();
FieldInfo fi = getFieldInfos().fieldInfo(field);
if (fi == null
|| fi.getVectorDimension() == 0
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
// Field does not exist or does not index vectors
return null;
}
return getVectorReader().getByteVectorValues(field);
}
@Override @Override
public final TopDocs searchNearestVectors( public final TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

View File

@ -53,6 +53,11 @@ abstract class DocValuesLeafReader extends LeafReader {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();
} }
@Override
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
throw new UnsupportedOperationException();
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

View File

@ -323,6 +323,15 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return new ExitableVectorValues(vectorValues); return new ExitableVectorValues(vectorValues);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
final ByteVectorValues vectorValues = in.getByteVectorValues(field);
if (vectorValues == null) {
return null;
}
return new ExitableByteVectorValues(vectorValues);
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
@ -387,17 +396,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
} }
} }
private class ExitableVectorValues extends FilterVectorValues { private class ExitableVectorValues extends VectorValues {
private int docToCheck; private int docToCheck;
private final VectorValues vectorValues;
public ExitableVectorValues(VectorValues vectorValues) { public ExitableVectorValues(VectorValues vectorValues) {
super(vectorValues); this.vectorValues = vectorValues;
docToCheck = 0; docToCheck = 0;
} }
@Override @Override
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
final int advance = super.advance(target); final int advance = vectorValues.advance(target);
if (advance >= docToCheck) { if (advance >= docToCheck) {
checkAndThrow(); checkAndThrow();
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
@ -405,9 +415,14 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return advance; return advance;
} }
@Override
public int docID() {
return vectorValues.docID();
}
@Override @Override
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
final int nextDoc = super.nextDoc(); final int nextDoc = vectorValues.nextDoc();
if (nextDoc >= docToCheck) { if (nextDoc >= docToCheck) {
checkAndThrow(); checkAndThrow();
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
@ -415,14 +430,91 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
return nextDoc; return nextDoc;
} }
@Override
public int dimension() {
return vectorValues.dimension();
}
@Override @Override
public float[] vectorValue() throws IOException { public float[] vectorValue() throws IOException {
return in.vectorValue(); return vectorValues.vectorValue();
}
@Override
public int size() {
return vectorValues.size();
} }
@Override @Override
public BytesRef binaryValue() throws IOException { public BytesRef binaryValue() throws IOException {
return in.binaryValue(); return vectorValues.binaryValue();
}
/**
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
* if {@link Thread#interrupted()} returns true.
*/
private void checkAndThrow() {
if (queryTimeout.shouldExit()) {
throw new ExitingReaderException(
"The request took too long to iterate over vector values. Timeout: "
+ queryTimeout.toString()
+ ", VectorValues="
+ in);
} else if (Thread.interrupted()) {
throw new ExitingReaderException(
"Interrupted while iterating over vector values. VectorValues=" + in);
}
}
}
private class ExitableByteVectorValues extends ByteVectorValues {
private int docToCheck;
private final ByteVectorValues vectorValues;
public ExitableByteVectorValues(ByteVectorValues vectorValues) {
this.vectorValues = vectorValues;
docToCheck = 0;
}
@Override
public int advance(int target) throws IOException {
final int advance = vectorValues.advance(target);
if (advance >= docToCheck) {
checkAndThrow();
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return advance;
}
@Override
public int docID() {
return vectorValues.docID();
}
@Override
public int nextDoc() throws IOException {
final int nextDoc = vectorValues.nextDoc();
if (nextDoc >= docToCheck) {
checkAndThrow();
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
}
return nextDoc;
}
@Override
public int dimension() {
return vectorValues.dimension();
}
@Override
public int size() {
return vectorValues.size();
}
@Override
public BytesRef vectorValue() throws IOException {
return vectorValues.vectorValue();
} }
/** /**

View File

@ -351,6 +351,11 @@ public abstract class FilterLeafReader extends LeafReader {
return in.getVectorValues(field); return in.getVectorValues(field);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return in.getByteVectorValues(field);
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

View File

@ -38,6 +38,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.PointsFormat; import org.apache.lucene.codecs.PointsFormat;
import org.apache.lucene.codecs.PointsWriter; import org.apache.lucene.codecs.PointsWriter;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
@ -721,11 +722,7 @@ final class IndexingChain implements Accountable {
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue()); pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
} }
if (fieldType.vectorDimension() != 0) { if (fieldType.vectorDimension() != 0) {
switch (fieldType.vectorEncoding()) { indexVectorValue(docID, pf, fieldType.vectorEncoding(), field);
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
docID, ((KnnVectorField) field).vectorValue());
}
} }
return indexedField; return indexedField;
} }
@ -959,6 +956,18 @@ final class IndexingChain implements Accountable {
} }
} }
@SuppressWarnings("unchecked")
private void indexVectorValue(
int docID, PerField pf, VectorEncoding vectorEncoding, IndexableField field)
throws IOException {
switch (vectorEncoding) {
case BYTE -> ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
.addValue(docID, ((KnnByteVectorField) field).vectorValue());
case FLOAT32 -> ((KnnFieldVectorsWriter<float[]>) pf.knnFieldVectorsWriter)
.addValue(docID, ((KnnVectorField) field).vectorValue());
}
}
/** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */ /** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */
private PerField getPerField(String name) { private PerField getPerField(String name) {
final int hashPos = name.hashCode() & hashMask; final int hashPos = name.hashCode() & hashMask;

View File

@ -208,6 +208,14 @@ public abstract class LeafReader extends IndexReader {
*/ */
public abstract VectorValues getVectorValues(String field) throws IOException; public abstract VectorValues getVectorValues(String field) throws IOException;
/**
* Returns {@link ByteVectorValues} for this field, or null if no {@link ByteVectorValues} were
* indexed. The returned instance should only be used by a single thread.
*
* @lucene.experimental
*/
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
/** /**
* Return the k nearest neighbor documents as determined by comparison of their vector values for * Return the k nearest neighbor documents as determined by comparison of their vector values for
* this field, to the given vector, by the field's similarity function. The score of each document * this field, to the given vector, by the field's similarity function. The score of each document

View File

@ -408,6 +408,13 @@ public class ParallelLeafReader extends LeafReader {
return reader == null ? null : reader.getVectorValues(fieldName); return reader == null ? null : reader.getVectorValues(fieldName);
} }
@Override
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
ensureOpen();
LeafReader reader = fieldToReader.get(fieldName);
return reader == null ? null : reader.getByteVectorValues(fieldName);
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit) String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)

View File

@ -168,6 +168,11 @@ public final class SlowCodecReaderWrapper {
return reader.getVectorValues(field); return reader.getVectorValues(field);
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return reader.getByteVectorValues(field);
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {

View File

@ -222,34 +222,21 @@ public final class SortingCodecReader extends FilterCodecReader {
final FixedBitSet docsWithField; final FixedBitSet docsWithField;
final float[][] vectors; final float[][] vectors;
final ByteBuffer vectorAsBytes; final ByteBuffer vectorAsBytes;
final BytesRef[] binaryVectors;
private int docId = -1; private int docId = -1;
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap, VectorEncoding encoding) SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
throws IOException {
this.size = delegate.size(); this.size = delegate.size();
this.dimension = delegate.dimension(); this.dimension = delegate.dimension();
docsWithField = new FixedBitSet(sortMap.size()); docsWithField = new FixedBitSet(sortMap.size());
if (encoding == VectorEncoding.BYTE) { vectors = new float[sortMap.size()][];
vectors = null; vectorAsBytes =
binaryVectors = new BytesRef[sortMap.size()]; ByteBuffer.allocate(delegate.dimension() * VectorEncoding.FLOAT32.byteSize)
vectorAsBytes = null; .order(ByteOrder.LITTLE_ENDIAN);
} else {
vectors = new float[sortMap.size()][];
binaryVectors = null;
vectorAsBytes =
ByteBuffer.allocate(delegate.dimension() * encoding.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
}
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) { for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc); int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID); docsWithField.set(newDocID);
if (encoding == VectorEncoding.BYTE) { vectors[newDocID] = delegate.vectorValue().clone();
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.binaryValue());
} else {
vectors[newDocID] = delegate.vectorValue().clone();
}
} }
} }
@ -265,12 +252,8 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override @Override
public BytesRef binaryValue() throws IOException { public BytesRef binaryValue() throws IOException {
if (binaryVectors != null) { vectorAsBytes.asFloatBuffer().put(vectors[docId]);
return binaryVectors[docId]; return new BytesRef(vectorAsBytes.array());
} else {
vectorAsBytes.asFloatBuffer().put(vectors[docId]);
return new BytesRef(vectorAsBytes.array());
}
} }
@Override @Override
@ -297,6 +280,60 @@ public final class SortingCodecReader extends FilterCodecReader {
} }
} }
private static class SortingByteVectorValues extends ByteVectorValues {
final int size;
final int dimension;
final FixedBitSet docsWithField;
final BytesRef[] binaryVectors;
private int docId = -1;
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.size = delegate.size();
this.dimension = delegate.dimension();
docsWithField = new FixedBitSet(sortMap.size());
binaryVectors = new BytesRef[sortMap.size()];
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
int newDocID = sortMap.oldToNew(doc);
docsWithField.set(newDocID);
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
}
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
return advance(docId + 1);
}
@Override
public BytesRef vectorValue() throws IOException {
return binaryVectors[docId];
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return size;
}
@Override
public int advance(int target) throws IOException {
if (target >= docsWithField.length()) {
return NO_MORE_DOCS;
}
return docId = docsWithField.nextSetBit(target);
}
}
/** /**
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code> * Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
* . If the reader is already sorted, this method might return the reader as-is. * . If the reader is already sorted, this method might return the reader as-is.
@ -465,9 +502,12 @@ public final class SortingCodecReader extends FilterCodecReader {
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
FieldInfo fi = in.getFieldInfos().fieldInfo(field); return new SortingVectorValues(delegate.getVectorValues(field), docMap);
return new SortingVectorValues( }
delegate.getVectorValues(field), docMap, fi.getVectorEncoding());
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return new SortingByteVectorValues(delegate.getByteVectorValues(field), docMap);
} }
@Override @Override

View File

@ -61,8 +61,8 @@ public abstract class VectorValues extends DocIdSetIterator {
/** /**
* Return the binary encoded vector value for the current document ID. These are the bytes * Return the binary encoded vector value for the current document ID. These are the bytes
* corresponding to the float array return by {@link #vectorValue}. It is illegal to call this * corresponding to the array return by {@link #vectorValue}. It is illegal to call this method
* method when the iterator is not positioned: before advancing, or after failing to advance. The * when the iterator is not positioned: before advancing, or after failing to advance. The
* returned storage may be shared across calls, re-used and modified as the iterator advances. * returned storage may be shared across calls, re-used and modified as the iterator advances.
* *
* @return the binary value * @return the binary value

View File

@ -31,7 +31,8 @@ import org.apache.lucene.index.Terms;
/** /**
* A {@link Query} that matches documents that contain either a {@link * A {@link Query} that matches documents that contain either a {@link
* org.apache.lucene.document.KnnVectorField}, or a field that indexes norms or doc values. * org.apache.lucene.document.KnnVectorField}, {@link org.apache.lucene.document.KnnByteVectorField}
* or a field that indexes norms or doc values.
*/ */
public class FieldExistsQuery extends Query { public class FieldExistsQuery extends Query {
private String field; private String field;
@ -127,7 +128,12 @@ public class FieldExistsQuery extends Query {
break; break;
} }
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
if (leaf.getVectorValues(field).size() != leaf.maxDoc()) { int numVectors =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> leaf.getVectorValues(field).size();
case BYTE -> leaf.getByteVectorValues(field).size();
};
if (numVectors != leaf.maxDoc()) {
allReadersRewritable = false; allReadersRewritable = false;
break; break;
} }
@ -175,7 +181,11 @@ public class FieldExistsQuery extends Query {
if (fieldInfo.hasNorms()) { // the field indexes norms if (fieldInfo.hasNorms()) { // the field indexes norms
iterator = context.reader().getNormValues(field); iterator = context.reader().getNormValues(field);
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
iterator = context.reader().getVectorValues(field); iterator =
switch (fieldInfo.getVectorEncoding()) {
case FLOAT32 -> context.reader().getVectorValues(field);
case BYTE -> context.reader().getByteVectorValues(field);
};
} else if (fieldInfo.getDocValuesType() } else if (fieldInfo.getDocValuesType()
!= DocValuesType.NONE) { // the field indexes doc values != DocValuesType.NONE) { // the field indexes doc values
switch (fieldInfo.getDocValuesType()) { switch (fieldInfo.getDocValuesType()) {

View File

@ -54,7 +54,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param k the number of documents to find * @param k the number of documents to find
* @throws IllegalArgumentException if <code>k</code> is less than 1 * @throws IllegalArgumentException if <code>k</code> is less than 1
*/ */
public KnnByteVectorQuery(String field, byte[] target, int k) { public KnnByteVectorQuery(String field, BytesRef target, int k) {
this(field, target, k, null); this(field, target, k, null);
} }
@ -68,9 +68,9 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
* @param filter a filter applied before the vector search * @param filter a filter applied before the vector search
* @throws IllegalArgumentException if <code>k</code> is less than 1 * @throws IllegalArgumentException if <code>k</code> is less than 1
*/ */
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
super(field, k, filter); super(field, k, filter);
this.target = new BytesRef(target); this.target = target;
} }
@Override @Override

View File

@ -17,6 +17,7 @@
package org.apache.lucene.search; package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
@ -29,7 +30,6 @@ import org.apache.lucene.util.BytesRef;
* search over the vectors. * search over the vectors.
*/ */
abstract class VectorScorer { abstract class VectorScorer {
protected final VectorValues values;
protected final VectorSimilarityFunction similarity; protected final VectorSimilarityFunction similarity;
/** /**
@ -48,53 +48,72 @@ abstract class VectorScorer {
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query) static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
throws IOException { throws IOException {
VectorValues values = context.reader().getVectorValues(fi.name); ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction(); VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
return new ByteVectorScorer(values, query, similarity); return new ByteVectorScorer(values, query, similarity);
} }
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) { VectorScorer(VectorSimilarityFunction similarity) {
this.values = values;
this.similarity = similarity; this.similarity = similarity;
} }
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
}
/** Compute the similarity score for the current document. */ /** Compute the similarity score for the current document. */
abstract float score() throws IOException; abstract float score() throws IOException;
abstract boolean advanceExact(int doc) throws IOException;
private static class ByteVectorScorer extends VectorScorer { private static class ByteVectorScorer extends VectorScorer {
private final BytesRef query; private final BytesRef query;
private final ByteVectorValues values;
protected ByteVectorScorer( protected ByteVectorScorer(
VectorValues values, BytesRef query, VectorSimilarityFunction similarity) { ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
super(values, similarity); super(similarity);
this.values = values;
this.query = query; this.query = query;
} }
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
@Override
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
}
@Override @Override
public float score() throws IOException { public float score() throws IOException {
return similarity.compare(query, values.binaryValue()); return similarity.compare(query, values.vectorValue());
} }
} }
private static class FloatVectorScorer extends VectorScorer { private static class FloatVectorScorer extends VectorScorer {
private final float[] query; private final float[] query;
private final VectorValues values;
protected FloatVectorScorer( protected FloatVectorScorer(
VectorValues values, float[] query, VectorSimilarityFunction similarity) { VectorValues values, float[] query, VectorSimilarityFunction similarity) {
super(values, similarity); super(similarity);
this.query = query; this.query = query;
this.values = values;
}
/**
* Advance the instance to the given document ID and return true if there is a value for that
* document.
*/
@Override
public boolean advanceExact(int doc) throws IOException {
int vectorDoc = values.docID();
if (vectorDoc < doc) {
vectorDoc = values.advance(doc);
}
return vectorDoc == doc;
} }
@Override @Override

View File

@ -53,7 +53,7 @@ public final class HnswGraphBuilder<T> {
private final VectorSimilarityFunction similarityFunction; private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding; private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues vectors; private final RandomAccessVectorValues<T> vectors;
private final SplittableRandom random; private final SplittableRandom random;
private final HnswGraphSearcher<T> graphSearcher; private final HnswGraphSearcher<T> graphSearcher;
@ -63,10 +63,10 @@ public final class HnswGraphBuilder<T> {
// we need two sources of vectors in order to perform diversity check comparisons without // we need two sources of vectors in order to perform diversity check comparisons without
// colliding // colliding
private final RandomAccessVectorValues vectorsCopy; private final RandomAccessVectorValues<T> vectorsCopy;
public static HnswGraphBuilder<?> create( public static <T> HnswGraphBuilder<T> create(
RandomAccessVectorValues vectors, RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding, VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
int M, int M,
@ -89,7 +89,7 @@ public final class HnswGraphBuilder<T> {
* to ensure repeatable construction. * to ensure repeatable construction.
*/ */
private HnswGraphBuilder( private HnswGraphBuilder(
RandomAccessVectorValues vectors, RandomAccessVectorValues<T> vectors,
VectorEncoding vectorEncoding, VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
int M, int M,
@ -131,7 +131,7 @@ public final class HnswGraphBuilder<T> {
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an * @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
* independent accessor for the vectors * independent accessor for the vectors
*/ */
public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException { public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
if (vectorsToAdd == this.vectors) { if (vectorsToAdd == this.vectors) {
throw new IllegalArgumentException( throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
@ -143,7 +143,7 @@ public final class HnswGraphBuilder<T> {
return hnsw; return hnsw;
} }
private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException { private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
long start = System.nanoTime(), t = start; long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor // start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectorsToAdd.size(); node++) { for (int node = 1; node < vectorsToAdd.size(); node++) {
@ -189,16 +189,8 @@ public final class HnswGraphBuilder<T> {
} }
} }
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException { public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
addGraphNode(node, getValue(node, values)); addGraphNode(node, values.vectorValue(node));
}
@SuppressWarnings("unchecked")
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
return switch (vectorEncoding) {
case BYTE -> (T) values.binaryValue(node);
case FLOAT32 -> (T) values.vectorValue(node);
};
} }
private long printGraphBuildStatus(int node, long start, long t) { private long printGraphBuildStatus(int node, long start, long t) {
@ -281,8 +273,8 @@ public final class HnswGraphBuilder<T> {
private boolean isDiverse(int candidate, NeighborArray neighbors, float score) private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException { throws IOException {
return switch (vectorEncoding) { return switch (vectorEncoding) {
case BYTE -> isDiverse(vectors.binaryValue(candidate), neighbors, score); case BYTE -> isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse(vectors.vectorValue(candidate), neighbors, score); case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
}; };
} }
@ -290,7 +282,8 @@ public final class HnswGraphBuilder<T> {
throws IOException { throws IOException {
for (int i = 0; i < neighbors.size(); i++) { for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity = float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i])); similarityFunction.compare(
candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) { if (neighborSimilarity >= score) {
return false; return false;
} }
@ -302,7 +295,8 @@ public final class HnswGraphBuilder<T> {
throws IOException { throws IOException {
for (int i = 0; i < neighbors.size(); i++) { for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity = float neighborSimilarity =
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i])); similarityFunction.compare(
candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) { if (neighborSimilarity >= score) {
return false; return false;
} }
@ -327,9 +321,10 @@ public final class HnswGraphBuilder<T> {
throws IOException { throws IOException {
int candidateNode = neighbors.node[candidateIndex]; int candidateNode = neighbors.node[candidateIndex];
return switch (vectorEncoding) { return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors); case BYTE -> isWorstNonDiverse(
candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
case FLOAT32 -> isWorstNonDiverse( case FLOAT32 -> isWorstNonDiverse(
candidateIndex, vectors.vectorValue(candidateNode), neighbors); candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
}; };
} }
@ -338,7 +333,8 @@ public final class HnswGraphBuilder<T> {
float minAcceptedSimilarity = neighbors.score[candidateIndex]; float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) { for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity = float neighborSimilarity =
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i])); similarityFunction.compare(
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node // candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) { if (neighborSimilarity >= minAcceptedSimilarity) {
return true; return true;
@ -352,7 +348,8 @@ public final class HnswGraphBuilder<T> {
float minAcceptedSimilarity = neighbors.score[candidateIndex]; float minAcceptedSimilarity = neighbors.score[candidateIndex];
for (int i = candidateIndex - 1; i >= 0; i--) { for (int i = candidateIndex - 1; i >= 0; i--) {
float neighborSimilarity = float neighborSimilarity =
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i])); similarityFunction.compare(
candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
// candidate node is too similar to node i given its score relative to the base node // candidate node is too similar to node i given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) { if (neighborSimilarity >= minAcceptedSimilarity) {
return true; return true;

View File

@ -81,7 +81,7 @@ public class HnswGraphSearcher<T> {
public static NeighborQueue search( public static NeighborQueue search(
float[] query, float[] query,
int topK, int topK,
RandomAccessVectorValues vectors, RandomAccessVectorValues<float[]> vectors,
VectorEncoding vectorEncoding, VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
HnswGraph graph, HnswGraph graph,
@ -137,7 +137,7 @@ public class HnswGraphSearcher<T> {
public static NeighborQueue search( public static NeighborQueue search(
BytesRef query, BytesRef query,
int topK, int topK,
RandomAccessVectorValues vectors, RandomAccessVectorValues<BytesRef> vectors,
VectorEncoding vectorEncoding, VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction, VectorSimilarityFunction similarityFunction,
HnswGraph graph, HnswGraph graph,
@ -198,7 +198,7 @@ public class HnswGraphSearcher<T> {
int topK, int topK,
int level, int level,
final int[] eps, final int[] eps,
RandomAccessVectorValues vectors, RandomAccessVectorValues<T> vectors,
HnswGraph graph) HnswGraph graph)
throws IOException { throws IOException {
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE); return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
@ -209,7 +209,7 @@ public class HnswGraphSearcher<T> {
int topK, int topK,
int level, int level,
final int[] eps, final int[] eps,
RandomAccessVectorValues vectors, RandomAccessVectorValues<T> vectors,
HnswGraph graph, HnswGraph graph,
Bits acceptOrds, Bits acceptOrds,
int visitedLimit) int visitedLimit)
@ -279,11 +279,11 @@ public class HnswGraphSearcher<T> {
return results; return results;
} }
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException { private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
if (vectorEncoding == VectorEncoding.BYTE) { if (vectorEncoding == VectorEncoding.BYTE) {
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord)); return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
} else { } else {
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord)); return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
} }
} }

View File

@ -18,7 +18,6 @@
package org.apache.lucene.util.hnsw; package org.apache.lucene.util.hnsw;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.util.BytesRef;
/** /**
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
@ -26,7 +25,7 @@ import org.apache.lucene.util.BytesRef;
* *
* @lucene.experimental * @lucene.experimental
*/ */
public interface RandomAccessVectorValues { public interface RandomAccessVectorValues<T> {
/** Return the number of vector values */ /** Return the number of vector values */
int size(); int size();
@ -35,26 +34,16 @@ public interface RandomAccessVectorValues {
int dimension(); int dimension();
/** /**
* Return the vector value indexed at the given ordinal. The provided floating point array may be * Return the vector value indexed at the given ordinal.
* shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}.
* *
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}. * @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/ */
float[] vectorValue(int targetOrd) throws IOException; T vectorValue(int targetOrd) throws IOException;
/**
* Return the vector indexed at the given ordinal value as an array of bytes in a BytesRef; these
* are the bytes corresponding to the float array. The provided bytes may be shared and
* overwritten by subsequent calls to this method and {@link #vectorValue(int)}.
*
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
BytesRef binaryValue(int targetOrd) throws IOException;
/** /**
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to * Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
* access different values at once, to avoid overwriting the underlying float vector returned by * access different values at once, to avoid overwriting the underlying float vector returned by
* {@link RandomAccessVectorValues#vectorValue}. * {@link RandomAccessVectorValues#vectorValue}.
*/ */
RandomAccessVectorValues copy() throws IOException; RandomAccessVectorValues<T> copy() throws IOException;
} }

View File

@ -21,6 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.StringReader; import java.io.StringReader;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.Codec;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriter;
@ -611,25 +612,22 @@ public class TestField extends LuceneTestCase {
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document(); Document doc = new Document();
BytesRef br = newBytesRef(new byte[5]); BytesRef br = newBytesRef(new byte[5]);
Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN); Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
expectThrows( expectThrows(
IllegalArgumentException.class, IllegalArgumentException.class,
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType())); () -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
float[] vector = new float[] {1, 2}; float[] vector = new float[] {1, 2};
Field field2 = new KnnVectorField("float", vector); Field field2 = new KnnVectorField("float", vector);
expectThrows(
IllegalArgumentException.class,
() -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
assertEquals(br, field.binaryValue()); assertEquals(br, field.binaryValue());
doc.add(field); doc.add(field);
doc.add(field2); doc.add(field2);
w.addDocument(doc); w.addDocument(doc);
try (IndexReader r = DirectoryReader.open(w)) { try (IndexReader r = DirectoryReader.open(w)) {
VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary"); ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
assertEquals(1, binary.size()); assertEquals(1, binary.size());
assertNotEquals(NO_MORE_DOCS, binary.nextDoc()); assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
assertEquals(br, binary.binaryValue());
assertNotNull(binary.vectorValue()); assertNotNull(binary.vectorValue());
assertEquals(br, binary.vectorValue());
assertEquals(NO_MORE_DOCS, binary.nextDoc()); assertEquals(NO_MORE_DOCS, binary.nextDoc());
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float"); VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");

View File

@ -112,6 +112,11 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
return null; return null;
} }
@Override
public ByteVectorValues getByteVectorValues(String field) {
return null;
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {

View File

@ -17,7 +17,7 @@
package org.apache.lucene.search; package org.apache.lucene.search;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
@ -27,12 +27,12 @@ import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override @Override
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter); return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
} }
@Override @Override
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query); return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
} }
@Override @Override
@ -49,12 +49,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@Override @Override
Field getKnnVectorField( Field getKnnVectorField(
String name, float[] vector, VectorSimilarityFunction similarityFunction) { String name, float[] vector, VectorSimilarityFunction similarityFunction) {
return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction); return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
} }
@Override @Override
Field getKnnVectorField(String name, float[] vector) { Field getKnnVectorField(String name, float[] vector) {
return new KnnVectorField( return new KnnByteVectorField(
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN); name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
} }
@ -80,7 +80,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
super(field, target, k, filter); super(field, target, k, filter);
} }

View File

@ -22,6 +22,7 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
@ -79,7 +80,7 @@ public class TestVectorScorer extends LuceneTestCase {
for (int j = 0; j < v.length; j++) { for (int j = 0; j < v.length; j++) {
v.bytes[j] = (byte) contents[i][j]; v.bytes[j] = (byte) contents[i][j];
} }
doc.add(new KnnVectorField(field, v, EUCLIDEAN)); doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
} else { } else {
doc.add(new KnnVectorField(field, contents[i])); doc.add(new KnnVectorField(field, contents[i]));
} }

View File

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import org.apache.lucene.util.BytesRef;
/**
 * Base class for mock vector sources used by the HNSW graph tests. Holds a fixed array of vectors
 * and exposes them two ways: random access by dense ordinal (the {@link RandomAccessVectorValues}
 * contract) and forward, DocIdSetIterator-style iteration over doc ids via {@link #docID()},
 * {@link #nextDoc()} and {@link #advance(int)}.
 *
 * <p>{@code values} is indexed by doc id and may contain null entries (docs without a vector);
 * {@code denseValues} is indexed by dense ordinal with no gaps.
 */
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T> {

  /** Number of components in each vector. */
  protected final int dimension;

  /** Vectors addressed by dense ordinal (no null gaps); backs {@link #vectorValue(int)}. */
  protected final T[] denseValues;

  /** Vectors addressed by doc id; a null entry means that doc has no vector. */
  protected final T[] values;

  /** Total number of vectors reported by {@link #size()}. */
  protected final int numVectors;

  /** Shared scratch buffer for byte-encoded vectors; see constructor comment. */
  protected final BytesRef binaryValue;

  /** Current cursor position (doc id); -1 means the iterator is unpositioned. */
  protected int pos = -1;

  /**
   * @param values vectors by doc id, with nulls for docs that have no vector
   * @param dimension number of components per vector
   * @param denseValues vectors by dense ordinal, no gaps
   * @param numVectors total vector count returned from {@link #size()}
   */
  AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
    this.dimension = dimension;
    this.values = values;
    this.denseValues = denseValues;
    // used by tests that build a graph from bytes rather than floats
    binaryValue = new BytesRef(dimension);
    binaryValue.length = dimension;
    this.numVectors = numVectors;
  }

  @Override
  public int size() {
    return numVectors;
  }

  @Override
  public int dimension() {
    return dimension;
  }

  @Override
  public T vectorValue(int targetOrd) {
    // random-access path: looks up by ordinal, independent of the iterator cursor
    return denseValues[targetOrd];
  }

  @Override
  public abstract AbstractMockVectorValues<T> copy();

  /** Returns the vector at the current cursor position; subclasses define the element type. */
  public abstract T vectorValue() throws IOException;

  /**
   * Positions the cursor at {@code target} if that doc has a vector; returns false (leaving the
   * cursor unchanged) otherwise.
   */
  private boolean seek(int target) {
    if (target >= 0 && target < values.length && values[target] != null) {
      pos = target;
      return true;
    } else {
      return false;
    }
  }

  /** Returns the current cursor position (doc id), or -1 before iteration starts. */
  public int docID() {
    return pos;
  }

  /** Advances the cursor to the next doc that has a vector, or NO_MORE_DOCS. */
  public int nextDoc() {
    return advance(pos + 1);
  }

  // NOTE(review): the target argument is never read -- the loop scans forward one doc at a time
  // from the current position, so advance(target) behaves exactly like nextDoc(). That is
  // sufficient for the forward-only iteration these tests perform; confirm no caller relies on
  // skipping ahead to an arbitrary target.
  public int advance(int target) {
    // ++pos in the loop condition also ensures pos ends at values.length when exhausted
    while (++pos < values.length) {
      if (seek(pos)) {
        return pos;
      }
    }
    return NO_MORE_DOCS;
  }
}

View File

@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed; import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
import static org.apache.lucene.util.VectorUtil.toBytesRef;
import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException; import java.io.IOException;
@ -36,21 +35,23 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader; import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.Field;
import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField; import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.StoredFields;
import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues; import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery; import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortField;
@ -65,19 +66,30 @@ import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator; import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator; import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.junit.Before;
/** Tests HNSW KNN graphs */ /** Tests HNSW KNN graphs */
public class TestHnswGraph extends LuceneTestCase { abstract class HnswGraphTestCase<T> extends LuceneTestCase {
VectorSimilarityFunction similarityFunction; VectorSimilarityFunction similarityFunction;
VectorEncoding vectorEncoding;
@Before abstract VectorEncoding getVectorEncoding();
public void setup() {
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); abstract Query knnQuery(String field, T vector, int k);
vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
} abstract T randomVector(int dim);
abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
throws IOException;
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
abstract T getTargetVector();
// test writing out and reading in a graph gives the expected graph // test writing out and reading in a graph gives the expected graph
public void testReadWrite() throws IOException { public void testReadWrite() throws IOException {
@ -86,10 +98,11 @@ public class TestHnswGraph extends LuceneTestCase {
int M = random().nextInt(4) + 2; int M = random().nextInt(4) + 2;
int beamWidth = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 5;
long seed = random().nextLong(); long seed = random().nextLong();
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random()); AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy(); AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed); HnswGraphBuilder.create(
vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors.copy()); HnswGraph hnsw = builder.build(vectors.copy());
// Recreate the graph while indexing with the same random seed and write it out // Recreate the graph while indexing with the same random seed and write it out
@ -115,7 +128,7 @@ public class TestHnswGraph extends LuceneTestCase {
indexedDoc++; indexedDoc++;
} }
Document doc = new Document(); Document doc = new Document();
doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction)); doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
doc.add(new StoredField("id", v2.docID())); doc.add(new StoredField("id", v2.docID()));
iw.addDocument(doc); iw.addDocument(doc);
nVec++; nVec++;
@ -124,7 +137,7 @@ public class TestHnswGraph extends LuceneTestCase {
} }
try (IndexReader reader = DirectoryReader.open(dir)) { try (IndexReader reader = DirectoryReader.open(dir)) {
for (LeafReaderContext ctx : reader.leaves()) { for (LeafReaderContext ctx : reader.leaves()) {
VectorValues values = ctx.reader().getVectorValues("field"); AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
assertEquals(dim, values.dimension()); assertEquals(dim, values.dimension());
assertEquals(nVec, values.size()); assertEquals(nVec, values.size());
assertEquals(indexedDoc, ctx.reader().maxDoc()); assertEquals(indexedDoc, ctx.reader().maxDoc());
@ -142,15 +155,11 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
private VectorEncoding randomVectorEncoding() {
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
}
// test that sorted index returns the same search results are unsorted // test that sorted index returns the same search results are unsorted
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
int dim = random().nextInt(10) + 3; int dim = random().nextInt(10) + 3;
int nDoc = random().nextInt(200) + 100; int nDoc = random().nextInt(200) + 100;
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random()); AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
int M = random().nextInt(10) + 5; int M = random().nextInt(10) + 5;
int beamWidth = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 5;
@ -190,7 +199,7 @@ public class TestHnswGraph extends LuceneTestCase {
indexedDoc++; indexedDoc++;
} }
Document doc = new Document(); Document doc = new Document();
doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction)); doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
doc.add(new StoredField("id", vectors.docID())); doc.add(new StoredField("id", vectors.docID()));
doc.add(new NumericDocValuesField("sortkey", random().nextLong())); doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
iw.addDocument(doc); iw.addDocument(doc);
@ -206,7 +215,7 @@ public class TestHnswGraph extends LuceneTestCase {
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
// ask to explore a lot of candidates to ensure the same returned hits, // ask to explore a lot of candidates to ensure the same returned hits,
// as graphs of 2 indices are organized differently // as graphs of 2 indices are organized differently
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50); Query query = knnQuery("vector", randomVector(dim), 50);
List<String> ids1 = new ArrayList<>(); List<String> ids1 = new ArrayList<>();
List<Integer> docs1 = new ArrayList<>(); List<Integer> docs1 = new ArrayList<>();
List<String> ids2 = new ArrayList<>(); List<String> ids2 = new ArrayList<>();
@ -241,7 +250,7 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException { void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels()); assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size()); assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
@ -271,32 +280,32 @@ public class TestHnswGraph extends LuceneTestCase {
// Make sure we actually approximately find the closest k elements. Mostly this is about // Make sure we actually approximately find the closest k elements. Mostly this is about
// ensuring that we have all the distance functions, comparators, priority queues and so on // ensuring that we have all the distance functions, comparators, priority queues and so on
// oriented in the right directions // oriented in the right directions
@SuppressWarnings("unchecked")
public void testAknnDiverse() throws IOException { public void testAknnDiverse() throws IOException {
int nDoc = 100; int nDoc = 100;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc); RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// run some searches // run some searches
NeighborQueue nn = NeighborQueue nn =
switch (vectorEncoding) { switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search( case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(), (BytesRef) getTargetVector(),
10, 10,
vectors.copy(), (RandomAccessVectorValues<BytesRef>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
null, null,
Integer.MAX_VALUE); Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search( case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(), (float[]) getTargetVector(),
10, 10,
vectors.copy(), (RandomAccessVectorValues<float[]>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
null, null,
@ -323,33 +332,33 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
@SuppressWarnings("unchecked")
public void testSearchWithAcceptOrds() throws IOException { public void testSearchWithAcceptOrds() throws IOException {
int nDoc = 100; int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc); RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
vectorEncoding = randomVectorEncoding(); HnswGraphBuilder<T> builder =
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// the first 10 docs must not be deleted to ensure the expected recall // the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size); Bits acceptOrds = createRandomAcceptOrds(10, nDoc);
NeighborQueue nn = NeighborQueue nn =
switch (vectorEncoding) { switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search( case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(), (BytesRef) getTargetVector(),
10, 10,
vectors.copy(), (RandomAccessVectorValues<BytesRef>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search( case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(), (float[]) getTargetVector(),
10, 10,
vectors.copy(), (RandomAccessVectorValues<float[]>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
@ -367,39 +376,39 @@ public class TestHnswGraph extends LuceneTestCase {
assertTrue("sum(result docs)=" + sum, sum < 75); assertTrue("sum(result docs)=" + sum, sum < 75);
} }
@SuppressWarnings("unchecked")
public void testSearchWithSelectiveAcceptOrds() throws IOException { public void testSearchWithSelectiveAcceptOrds() throws IOException {
int nDoc = 100; int nDoc = 100;
CircularVectorValues vectors = new CircularVectorValues(nDoc); RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// Only mark a few vectors as accepted // Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size); BitSet acceptOrds = new FixedBitSet(nDoc);
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) { for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) {
acceptOrds.set(i); acceptOrds.set(i);
} }
// Check the search finds all accepted vectors // Check the search finds all accepted vectors
int numAccepted = acceptOrds.cardinality(); int numAccepted = acceptOrds.cardinality();
NeighborQueue nn = NeighborQueue nn =
switch (vectorEncoding) { switch (getVectorEncoding()) {
case FLOAT32 -> HnswGraphSearcher.search( case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(), (float[]) getTargetVector(),
numAccepted, numAccepted,
vectors.copy(), (RandomAccessVectorValues<float[]>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
case BYTE -> HnswGraphSearcher.search( case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(), (BytesRef) getTargetVector(),
numAccepted, numAccepted,
vectors.copy(), (RandomAccessVectorValues<BytesRef>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
@ -413,81 +422,37 @@ public class TestHnswGraph extends LuceneTestCase {
} }
} }
private float[] getTargetVector() { @SuppressWarnings("unchecked")
return new float[] {1, 0};
}
private BytesRef getTargetByteVector() {
return new BytesRef(new byte[] {1, 0});
}
public void testSearchWithSkewedAcceptOrds() throws IOException {
int nDoc = 1000;
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
CircularVectorValues vectors = new CircularVectorValues(nDoc);
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// Skip over half of the documents that are closest to the query vector
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
for (int i = 500; i < nDoc; i++) {
acceptOrds.set(i);
}
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.copy(),
VectorEncoding.FLOAT32,
similarityFunction,
hnsw,
acceptOrds,
Integer.MAX_VALUE);
int[] nodes = nn.nodes();
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
int sum = 0;
for (int node : nodes) {
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
sum += node;
}
// We still expect to get reasonable recall. The lowest non-skipped docIds
// are closest to the query vector: sum(500,509) = 5045
assertTrue("sum(result docs)=" + sum, sum < 5100);
}
public void testVisitedLimit() throws IOException { public void testVisitedLimit() throws IOException {
int nDoc = 500; int nDoc = 500;
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
CircularVectorValues vectors = new CircularVectorValues(nDoc); RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
int topK = 50; int topK = 50;
int visitedLimit = topK + random().nextInt(5); int visitedLimit = topK + random().nextInt(5);
NeighborQueue nn = NeighborQueue nn =
switch (vectorEncoding) { switch (getVectorEncoding()) {
case FLOAT32 -> HnswGraphSearcher.search( case FLOAT32 -> HnswGraphSearcher.search(
getTargetVector(), (float[]) getTargetVector(),
topK, topK,
vectors.copy(), (RandomAccessVectorValues<float[]>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
createRandomAcceptOrds(0, vectors.size), createRandomAcceptOrds(0, nDoc),
visitedLimit); visitedLimit);
case BYTE -> HnswGraphSearcher.search( case BYTE -> HnswGraphSearcher.search(
getTargetByteVector(), (BytesRef) getTargetVector(),
topK, topK,
vectors.copy(), (RandomAccessVectorValues<BytesRef>) vectors.copy(),
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
createRandomAcceptOrds(0, vectors.size), createRandomAcceptOrds(0, nDoc),
visitedLimit); visitedLimit);
}; };
@ -504,8 +469,8 @@ public class TestHnswGraph extends LuceneTestCase {
IllegalArgumentException.class, IllegalArgumentException.class,
() -> () ->
HnswGraphBuilder.create( HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()), vectorValues(1, 1),
VectorEncoding.FLOAT32, getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
0, 0,
10, 10,
@ -515,8 +480,8 @@ public class TestHnswGraph extends LuceneTestCase {
IllegalArgumentException.class, IllegalArgumentException.class,
() -> () ->
HnswGraphBuilder.create( HnswGraphBuilder.create(
new RandomVectorValues(1, 1, random()), vectorValues(1, 1),
VectorEncoding.FLOAT32, getVectorEncoding(),
VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.EUCLIDEAN,
10, 10,
0, 0,
@ -530,13 +495,11 @@ public class TestHnswGraph extends LuceneTestCase {
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction similarityFunction =
RandomizedTest.randomFrom(VectorSimilarityFunction.values()); RandomizedTest.randomFrom(VectorSimilarityFunction.values());
VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values()); RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
TestHnswGraph.RandomVectorValues vectors =
new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong()); vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
long estimated = RamUsageEstimator.sizeOfObject(hnsw); long estimated = RamUsageEstimator.sizeOfObject(hnsw);
long actual = ramUsed(hnsw); long actual = ramUsed(hnsw);
@ -546,7 +509,6 @@ public class TestHnswGraph extends LuceneTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testDiversity() throws IOException { public void testDiversity() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
// Some carefully checked test cases with simple 2d vectors on the unit circle: // Some carefully checked test cases with simple 2d vectors on the unit circle:
float[][] values = { float[][] values = {
@ -558,21 +520,14 @@ public class TestHnswGraph extends LuceneTestCase {
unitVector2d(0.77), unitVector2d(0.77),
unitVector2d(0.6) unitVector2d(0.6)
}; };
if (vectorEncoding == VectorEncoding.BYTE) { AbstractMockVectorValues<T> vectors = vectorValues(values);
for (float[] v : values) {
for (int i = 0; i < v.length; i++) {
v[i] *= 127;
}
}
}
MockVectorValues vectors = new MockVectorValues(values);
// First add nodes until everybody gets a full neighbor list // First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor // node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0)); // builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy(); RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy); builder.addGraphNode(2, vectorsCopy);
// now every node has tried to attach every other node as a neighbor, but // now every node has tried to attach every other node as a neighbor, but
@ -609,7 +564,6 @@ public class TestHnswGraph extends LuceneTestCase {
} }
public void testDiversityFallback() throws IOException { public void testDiversityFallback() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN; similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// Some test cases can't be exercised in two dimensions; // Some test cases can't be exercised in two dimensions;
// in particular if a new neighbor displaces an existing neighbor // in particular if a new neighbor displaces an existing neighbor
@ -622,14 +576,14 @@ public class TestHnswGraph extends LuceneTestCase {
{10, 0, 0}, {10, 0, 0},
{0, 4, 0} {0, 4, 0}
}; };
MockVectorValues vectors = new MockVectorValues(values); AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list // First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor // node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0)); // builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy(); RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy); builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@ -647,7 +601,6 @@ public class TestHnswGraph extends LuceneTestCase {
} }
public void testDiversity3d() throws IOException { public void testDiversity3d() throws IOException {
vectorEncoding = randomVectorEncoding();
similarityFunction = VectorSimilarityFunction.EUCLIDEAN; similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives // test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
float[][] values = { float[][] values = {
@ -656,14 +609,14 @@ public class TestHnswGraph extends LuceneTestCase {
{0, 0, 20}, {0, 0, 20},
{0, 9, 0} {0, 9, 0}
}; };
MockVectorValues vectors = new MockVectorValues(values); AbstractMockVectorValues<T> vectors = vectorValues(values);
// First add nodes until everybody gets a full neighbor list // First add nodes until everybody gets a full neighbor list
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt()); vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
// node 0 is added by the builder constructor // node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0)); // builder.addGraphNode(vectors.vectorValue(0));
RandomAccessVectorValues vectorsCopy = vectors.copy(); RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy); builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy); builder.addGraphNode(2, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2); assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
@ -691,44 +644,38 @@ public class TestHnswGraph extends LuceneTestCase {
actual); actual);
} }
@SuppressWarnings("unchecked")
public void testRandom() throws IOException { public void testRandom() throws IOException {
int size = atLeast(100); int size = atLeast(100);
int dim = atLeast(10); int dim = atLeast(10);
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random()); AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
int topK = 5; int topK = 5;
HnswGraphBuilder<?> builder = HnswGraphBuilder<T> builder =
HnswGraphBuilder.create( HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong()); vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors.copy()); OnHeapHnswGraph hnsw = builder.build(vectors.copy());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size); Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0; int totalMatches = 0;
for (int i = 0; i < 100; i++) { for (int i = 0; i < 100; i++) {
NeighborQueue actual; NeighborQueue actual;
float[] query; T query = randomVector(dim);
BytesRef bQuery = null;
if (vectorEncoding == VectorEncoding.BYTE) {
query = randomVector8(random(), dim);
bQuery = toBytesRef(query);
} else {
query = randomVector(random(), dim);
}
actual = actual =
switch (vectorEncoding) { switch (getVectorEncoding()) {
case BYTE -> HnswGraphSearcher.search( case BYTE -> HnswGraphSearcher.search(
bQuery, (BytesRef) query,
100, 100,
vectors, (RandomAccessVectorValues<BytesRef>) vectors,
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
Integer.MAX_VALUE); Integer.MAX_VALUE);
case FLOAT32 -> HnswGraphSearcher.search( case FLOAT32 -> HnswGraphSearcher.search(
query, (float[]) query,
100, 100,
vectors, (RandomAccessVectorValues<float[]>) vectors,
vectorEncoding, getVectorEncoding(),
similarityFunction, similarityFunction,
hnsw, hnsw,
acceptOrds, acceptOrds,
@ -741,10 +688,14 @@ public class TestHnswGraph extends LuceneTestCase {
NeighborQueue expected = new NeighborQueue(topK, false); NeighborQueue expected = new NeighborQueue(topK, false);
for (int j = 0; j < size; j++) { for (int j = 0; j < size; j++) {
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
if (vectorEncoding == VectorEncoding.BYTE) { if (getVectorEncoding() == VectorEncoding.BYTE) {
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j))); assert query instanceof BytesRef;
expected.add(
j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
} else { } else {
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j))); assert query instanceof float[];
expected.add(
j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
} }
if (expected.size() > topK) { if (expected.size() > topK) {
expected.pop(); expected.pop();
@ -778,17 +729,16 @@ public class TestHnswGraph extends LuceneTestCase {
} }
/** Returns vectors evenly distributed around the upper unit semicircle. */ /** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues { static class CircularVectorValues extends VectorValues
implements RandomAccessVectorValues<float[]> {
private final int size; private final int size;
private final float[] value; private final float[] value;
private final BytesRef binaryValue;
int doc = -1; int doc = -1;
CircularVectorValues(int size) { CircularVectorValues(int size) {
this.size = size; this.size = size;
value = new float[2]; value = new float[2];
binaryValue = new BytesRef(new byte[2]);
} }
@Override @Override
@ -835,14 +785,70 @@ public class TestHnswGraph extends LuceneTestCase {
public float[] vectorValue(int ord) { public float[] vectorValue(int ord) {
return unitVector2d(ord / (double) size, value); return unitVector2d(ord / (double) size, value);
} }
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularByteVectorValues extends ByteVectorValues
implements RandomAccessVectorValues<BytesRef> {
private final int size;
private final float[] value;
private final BytesRef bValue;
int doc = -1;
CircularByteVectorValues(int size) {
this.size = size;
value = new float[2];
bValue = new BytesRef(new byte[2]);
}
@Override @Override
public BytesRef binaryValue(int ord) { public CircularByteVectorValues copy() {
float[] vectorValue = vectorValue(ord); return new CircularByteVectorValues(size);
for (int i = 0; i < vectorValue.length; i++) { }
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
@Override
public int dimension() {
return 2;
}
@Override
public int size() {
return size;
}
@Override
public BytesRef vectorValue() {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() {
return advance(doc + 1);
}
@Override
public int advance(int target) {
if (target >= 0 && target < size) {
doc = target;
} else {
doc = NO_MORE_DOCS;
} }
return binaryValue; return doc;
}
@Override
public BytesRef vectorValue(int ord) {
unitVector2d(ord / (double) size, value);
for (int i = 0; i < value.length; i++) {
bValue.bytes[i] = (byte) (value[i] * 127);
}
return bValue;
} }
} }
@ -864,7 +870,8 @@ public class TestHnswGraph extends LuceneTestCase {
return neighbors; return neighbors;
} }
private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException { void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
throws IOException {
int uDoc, vDoc; int uDoc, vDoc;
while (true) { while (true) {
uDoc = u.nextDoc(); uDoc = u.nextDoc();
@ -873,49 +880,40 @@ public class TestHnswGraph extends LuceneTestCase {
if (uDoc == NO_MORE_DOCS) { if (uDoc == NO_MORE_DOCS) {
break; break;
} }
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f; switch (getVectorEncoding()) {
assertArrayEquals( case BYTE:
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta); assertArrayEquals(
"vectors do not match for doc=" + uDoc,
((BytesRef) u.vectorValue()).bytes,
((BytesRef) v.vectorValue()).bytes);
break;
case FLOAT32:
assertArrayEquals(
"vectors do not match for doc=" + uDoc,
(float[]) u.vectorValue(),
(float[]) v.vectorValue(),
1e-4f);
break;
default:
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
}
} }
} }
/** Produces random vectors and caches them for random-access. */ static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
static class RandomVectorValues extends MockVectorValues { float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
RandomVectorValues(int size, int dimension, Random random) { vectors[offset] = randomVector(random, dimension);
super(createRandomVectors(size, dimension, null, random));
} }
return vectors;
}
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) { static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
super(createRandomVectors(size, dimension, vectorEncoding, random)); byte[][] vectors = new byte[size][];
} for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector8(random, dimension);
RandomVectorValues(RandomVectorValues other) {
super(other.values);
}
@Override
public RandomVectorValues copy() {
return new RandomVectorValues(this);
}
private static float[][] createRandomVectors(
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
float[][] vectors = new float[size][];
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
vectors[offset] = randomVector(random, dimension);
}
if (vectorEncoding == VectorEncoding.BYTE) {
for (float[] vector : vectors) {
if (vector != null) {
for (int i = 0; i < vector.length; i++) {
vector[i] = (byte) (127 * vector[i]);
}
}
}
}
return vectors;
} }
return vectors;
} }
/** /**
@ -937,7 +935,7 @@ public class TestHnswGraph extends LuceneTestCase {
return bits; return bits;
} }
private static float[] randomVector(Random random, int dim) { static float[] randomVector(Random random, int dim) {
float[] vec = new float[dim]; float[] vec = new float[dim];
for (int i = 0; i < dim; i++) { for (int i = 0; i < dim; i++) {
vec[i] = random.nextFloat(); vec[i] = random.nextFloat();
@ -949,11 +947,12 @@ public class TestHnswGraph extends LuceneTestCase {
return vec; return vec;
} }
private static float[] randomVector8(Random random, int dim) { static byte[] randomVector8(Random random, int dim) {
float[] fvec = randomVector(random, dim); float[] fvec = randomVector(random, dim);
byte[] bvec = new byte[dim];
for (int i = 0; i < dim; i++) { for (int i = 0; i < dim; i++) {
fvec[i] *= 127; bvec[i] = (byte) (fvec[i] * 127);
} }
return fvec; return bvec;
} }
} }

View File

@ -45,6 +45,7 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StoredField; import org.apache.lucene.document.StoredField;
import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.CodecReader;
@ -704,7 +705,11 @@ public class KnnGraphTester {
iwc.setUseCompoundFile(false); iwc.setUseCompoundFile(false);
// iwc.setMaxBufferedDocs(10000); // iwc.setMaxBufferedDocs(10000);
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction); FieldType fieldType =
switch (vectorEncoding) {
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
case FLOAT32 -> KnnVectorField.createFieldType(dim, similarityFunction);
};
if (quiet == false) { if (quiet == false) {
iwc.setInfoStream(new PrintStreamInfoStream(System.out)); iwc.setInfoStream(new PrintStreamInfoStream(System.out));
System.out.println("creating index in " + indexPath); System.out.println("creating index in " + indexPath);
@ -718,7 +723,7 @@ public class KnnGraphTester {
Document doc = new Document(); Document doc = new Document();
switch (vectorEncoding) { switch (vectorEncoding) {
case BYTE -> doc.add( case BYTE -> doc.add(
new KnnVectorField( new KnnByteVectorField(
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType)); KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType)); case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
} }

View File

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
/**
 * A mock byte-vector source for HNSW tests. Stores vectors as {@link BytesRef}s and supports both
 * sparse (null-padded by doc) and dense (ordinal-indexed) access via the shared base class.
 */
class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
  private final byte[] scratch;

  /**
   * Wraps raw byte arrays as a MockByteVectorValues. A null entry in {@code byteValues} denotes a
   * document with no vector; dimension is taken from the first entry, which must be non-null.
   */
  static MockByteVectorValues fromValues(byte[][] byteValues) {
    int dim = byteValues[0].length;
    BytesRef[] sparse = new BytesRef[byteValues.length];
    BytesRef[] dense = new BytesRef[byteValues.length];
    int numDense = 0;
    // single pass: wrap each present vector and append it to the dense (ordinal) view
    for (int i = 0; i < byteValues.length; i++) {
      if (byteValues[i] != null) {
        sparse[i] = new BytesRef(byteValues[i]);
        dense[numDense++] = sparse[i];
      }
    }
    return new MockByteVectorValues(sparse, dim, dense, numDense);
  }

  MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
    super(values, dimension, denseValues, numVectors);
    scratch = new byte[dimension];
  }

  @Override
  public MockByteVectorValues copy() {
    // deep-copy the backing arrays so the copy iterates independently of this instance
    return new MockByteVectorValues(
        ArrayUtil.copyOfSubArray(values, 0, values.length),
        dimension,
        ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
        numVectors);
  }

  @Override
  public BytesRef vectorValue() {
    if (LuceneTestCase.random().nextBoolean()) {
      return values[pos];
    }
    // Sometimes reuse one scratch array repeatedly, mimicking what the codec will do. This helps
    // catch aliasing bugs where the same ByteVectorValues source is used twice in a single
    // computation.
    System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
    return new BytesRef(scratch);
  }
}

View File

@ -17,52 +17,37 @@
package org.apache.lucene.util.hnsw; package org.apache.lucene.util.hnsw;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.ArrayUtil;
class MockVectorValues extends VectorValues implements RandomAccessVectorValues { class MockVectorValues extends AbstractMockVectorValues<float[]> {
private final float[] scratch; private final float[] scratch;
protected final int dimension; static MockVectorValues fromValues(float[][] values) {
protected final float[][] denseValues; int dimension = values[0].length;
protected final float[][] values;
private final int numVectors;
private final BytesRef binaryValue;
private int pos = -1;
MockVectorValues(float[][] values) {
this.dimension = values[0].length;
this.values = values;
int maxDoc = values.length; int maxDoc = values.length;
denseValues = new float[maxDoc][]; float[][] denseValues = new float[maxDoc][];
int count = 0; int count = 0;
for (int i = 0; i < maxDoc; i++) { for (int i = 0; i < maxDoc; i++) {
if (values[i] != null) { if (values[i] != null) {
denseValues[count++] = values[i]; denseValues[count++] = values[i];
} }
} }
numVectors = count; return new MockVectorValues(values, dimension, denseValues, count);
scratch = new float[dimension]; }
// used by tests that build a graph from bytes rather than floats
binaryValue = new BytesRef(dimension); MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
binaryValue.length = dimension; super(values, dimension, denseValues, numVectors);
this.scratch = new float[dimension];
} }
@Override @Override
public MockVectorValues copy() { public MockVectorValues copy() {
return new MockVectorValues(values); return new MockVectorValues(
} ArrayUtil.copyOfSubArray(values, 0, values.length),
dimension,
@Override ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
public int size() { numVectors);
return numVectors;
}
@Override
public int dimension() {
return dimension;
} }
@Override @Override
@ -83,42 +68,4 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
public float[] vectorValue(int targetOrd) { public float[] vectorValue(int targetOrd) {
return denseValues[targetOrd]; return denseValues[targetOrd];
} }
@Override
public BytesRef binaryValue(int targetOrd) {
float[] value = vectorValue(targetOrd);
for (int i = 0; i < value.length; i++) {
binaryValue.bytes[i] = (byte) value[i];
}
return binaryValue;
}
private boolean seek(int target) {
if (target >= 0 && target < values.length && values[target] != null) {
pos = target;
return true;
} else {
return false;
}
}
@Override
public int docID() {
return pos;
}
@Override
public int nextDoc() {
return advance(pos + 1);
}
@Override
public int advance(int target) {
while (++pos < values.length) {
if (seek(pos)) {
return pos;
}
}
return NO_MORE_DOCS;
}
} }

View File

@ -0,0 +1,118 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.junit.Before;
/** Tests HNSW KNN graphs over byte-encoded vectors ({@code VectorEncoding.BYTE}). */
public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
  @Before
  public void setup() {
    // exercise every similarity function across test runs
    similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
  }
  @Override
  VectorEncoding getVectorEncoding() {
    return VectorEncoding.BYTE;
  }
  @Override
  Query knnQuery(String field, BytesRef vector, int k) {
    return new KnnByteVectorQuery(field, vector, k);
  }
  @Override
  BytesRef randomVector(int dim) {
    return new BytesRef(randomVector8(random(), dim));
  }
  @Override
  AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
    return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
  }
  // true when v is an integral value representable as a (signed) byte without scaling
  static boolean fitsInByte(float v) {
    return v <= 127 && v >= -128 && v % 1 == 0;
  }
  @Override
  AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
    // Convert float test fixtures to bytes: either the floats already are byte values
    // (take them verbatim) or they are unit-scale floats (scale by 127).
    // NOTE(review): the choice is made from values[0][0] alone and assumed to hold for all
    // entries (asserted below) — mixed fixtures would trip the assert.
    byte[][] bValues = new byte[values.length][];
    // The case when all floats fit within a byte already.
    boolean scaleSimple = fitsInByte(values[0][0]);
    for (int i = 0; i < values.length; i++) {
      bValues[i] = new byte[values[i].length];
      for (int j = 0; j < values[i].length; j++) {
        final float v;
        if (scaleSimple) {
          assert fitsInByte(values[i][j]);
          v = values[i][j];
        } else {
          v = values[i][j] * 127;
        }
        bValues[i][j] = (byte) v;
      }
    }
    return MockByteVectorValues.fromValues(bValues);
  }
  @Override
  AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
      throws IOException {
    // Materialize the reader's byte vectors, indexed by docID (gaps stay null).
    ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
    byte[][] vectors = new byte[reader.maxDoc()][];
    while (vectorValues.nextDoc() != NO_MORE_DOCS) {
      // copy out of the (possibly reused) BytesRef backing array
      vectors[vectorValues.docID()] =
          ArrayUtil.copyOfSubArray(
              vectorValues.vectorValue().bytes,
              vectorValues.vectorValue().offset,
              vectorValues.vectorValue().offset + vectorValues.vectorValue().length);
    }
    return MockByteVectorValues.fromValues(vectors);
  }
  @Override
  Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
    return new KnnByteVectorField(name, vector, similarityFunction);
  }
  @Override
  RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
    return new CircularByteVectorValues(nDoc);
  }
  @Override
  BytesRef getTargetVector() {
    return new BytesRef(new byte[] {1, 0});
  }
}

View File

@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import com.carrotsearch.randomizedtesting.RandomizedTest;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.FixedBitSet;
import org.junit.Before;
/** Tests HNSW KNN graphs over float vectors ({@code VectorEncoding.FLOAT32}). */
public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
  @Before
  public void setup() {
    // exercise every similarity function across test runs
    similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
  }
  @Override
  VectorEncoding getVectorEncoding() {
    return VectorEncoding.FLOAT32;
  }
  @Override
  Query knnQuery(String field, float[] vector, int k) {
    return new KnnVectorQuery(field, vector, k);
  }
  @Override
  float[] randomVector(int dim) {
    return randomVector(random(), dim);
  }
  @Override
  AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
    return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
  }
  @Override
  AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
    // floats are the native encoding here: no conversion needed
    return MockVectorValues.fromValues(values);
  }
  @Override
  AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
      throws IOException {
    // Materialize the reader's float vectors, indexed by docID (gaps stay null).
    VectorValues vectorValues = reader.getVectorValues(fieldName);
    float[][] vectors = new float[reader.maxDoc()][];
    while (vectorValues.nextDoc() != NO_MORE_DOCS) {
      // copy: the iterator may reuse its returned array between docs
      vectors[vectorValues.docID()] =
          ArrayUtil.copyOfSubArray(
              vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
    }
    return MockVectorValues.fromValues(vectors);
  }
  @Override
  Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
    return new KnnVectorField(name, vector, similarityFunction);
  }
  @Override
  RandomAccessVectorValues<float[]> circularVectorValues(int nDoc) {
    return new CircularVectorValues(nDoc);
  }
  @Override
  float[] getTargetVector() {
    return new float[] {1f, 0f};
  }
  /**
   * Builds a graph over unit-circle vectors, then searches with an acceptOrds filter that rejects
   * the 500 ordinals closest to the target, verifying the search still honors the filter and
   * achieves reasonable recall among the accepted docs.
   */
  public void testSearchWithSkewedAcceptOrds() throws IOException {
    int nDoc = 1000;
    similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
    RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
    HnswGraphBuilder<float[]> builder =
        HnswGraphBuilder.create(
            vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
    OnHeapHnswGraph hnsw = builder.build(vectors.copy());
    // Skip over half of the documents that are closest to the query vector
    FixedBitSet acceptOrds = new FixedBitSet(nDoc);
    for (int i = 500; i < nDoc; i++) {
      acceptOrds.set(i);
    }
    NeighborQueue nn =
        HnswGraphSearcher.search(
            getTargetVector(),
            10,
            vectors.copy(),
            getVectorEncoding(),
            similarityFunction,
            hnsw,
            acceptOrds,
            Integer.MAX_VALUE);
    int[] nodes = nn.nodes();
    assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
    int sum = 0;
    for (int node : nodes) {
      assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
      sum += node;
    }
    // We still expect to get reasonable recall. The lowest non-skipped docIds
    // are closest to the query vector: sum(500,509) = 5045
    assertTrue("sum(result docs)=" + sum, sum < 5100);
  }
}

View File

@ -20,6 +20,7 @@ import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
@ -165,6 +166,11 @@ public class TermVectorLeafReader extends LeafReader {
return null; return null;
} }
@Override
public ByteVectorValues getByteVectorValues(String fieldName) {
return null;
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {

View File

@ -1395,6 +1395,11 @@ public class MemoryIndex {
return null; return null;
} }
@Override
public ByteVectorValues getByteVectorValues(String fieldName) {
return null;
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {

View File

@ -22,6 +22,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.MergeState; import org.apache.lucene.index.MergeState;
@ -113,7 +114,9 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
@Override @Override
public VectorValues getVectorValues(String field) throws IOException { public VectorValues getVectorValues(String field) throws IOException {
FieldInfo fi = fis.fieldInfo(field); FieldInfo fi = fis.fieldInfo(field);
assert fi != null && fi.getVectorDimension() > 0; assert fi != null
&& fi.getVectorDimension() > 0
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
VectorValues values = delegate.getVectorValues(field); VectorValues values = delegate.getVectorValues(field);
assert values != null; assert values != null;
assert values.docID() == -1; assert values.docID() == -1;
@ -122,6 +125,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
return values; return values;
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldInfo fi = fis.fieldInfo(field);
assert fi != null
&& fi.getVectorDimension() > 0
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
ByteVectorValues values = delegate.getByteVectorValues(field);
assert values != null;
assert values.docID() == -1;
assert values.size() >= 0;
assert values.dimension() > 0;
return values;
}
@Override @Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit) public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
throws IOException { throws IOException {

View File

@ -28,10 +28,12 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.FieldType; import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField; import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StoredField; import org.apache.lucene.document.StoredField;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CheckIndex; import org.apache.lucene.index.CheckIndex;
import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.DirectoryReader;
@ -79,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
@Override @Override
protected void addRandomFields(Document doc) { protected void addRandomFields(Document doc) {
switch (vectorEncoding) { switch (vectorEncoding) {
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction)); case BYTE -> doc.add(
new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction)); case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
} }
} }
@ -628,9 +631,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (random().nextInt(100) == 17) { if (random().nextInt(100) == 17) {
switch (fieldVectorEncodings[field]) { switch (fieldVectorEncodings[field]) {
case BYTE -> { case BYTE -> {
BytesRef b = randomVector8(fieldDims[field]); byte[] b = randomVector8(fieldDims[field]);
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field])); doc.add(
fieldTotals[field] += b.bytes[b.offset]; new KnnByteVectorField(
fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
fieldTotals[field] += b[0];
} }
case FLOAT32 -> { case FLOAT32 -> {
float[] v = randomVector(fieldDims[field]); float[] v = randomVector(fieldDims[field]);
@ -648,12 +653,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
int docCount = 0; int docCount = 0;
double checksum = 0; double checksum = 0;
String fieldName = "int" + field; String fieldName = "int" + field;
for (LeafReaderContext ctx : r.leaves()) { switch (fieldVectorEncodings[field]) {
VectorValues vectors = ctx.reader().getVectorValues(fieldName); case BYTE -> {
if (vectors != null) { for (LeafReaderContext ctx : r.leaves()) {
docCount += vectors.size(); ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
while (vectors.nextDoc() != NO_MORE_DOCS) { if (byteVectorValues != null) {
checksum += vectors.vectorValue()[0]; docCount += byteVectorValues.size();
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += byteVectorValues.vectorValue().bytes[0];
}
}
}
}
case FLOAT32 -> {
for (LeafReaderContext ctx : r.leaves()) {
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
if (vectorValues != null) {
docCount += vectorValues.size();
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += vectorValues.vectorValue()[0];
}
}
} }
} }
} }
@ -755,15 +775,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
LeafReader leaf = getOnlyLeafReader(reader); LeafReader leaf = getOnlyLeafReader(reader);
StoredFields storedFields = leaf.storedFields(); StoredFields storedFields = leaf.storedFields();
VectorValues vectorValues = leaf.getVectorValues(fieldName); ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
assertEquals(2, vectorValues.dimension()); assertEquals(2, vectorValues.dimension());
assertEquals(3, vectorValues.size()); assertEquals(3, vectorValues.size());
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id")); assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(-1f, vectorValues.vectorValue()[0], 0); assertEquals(-1, vectorValues.vectorValue().bytes[0], 0);
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id")); assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(1, vectorValues.vectorValue()[0], 0); assertEquals(1, vectorValues.vectorValue().bytes[0], 0);
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id")); assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
assertEquals(0, vectorValues.vectorValue()[0], 0); assertEquals(0, vectorValues.vectorValue().bytes[0], 0);
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
} }
} }
@ -915,7 +935,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int i = 0; i < numDoc; i++) { for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) { if (random().nextInt(7) != 3) {
// usually index a vector value for a doc // usually index a vector value for a doc
values[i] = randomVector8(dimension); values[i] = new BytesRef(randomVector8(dimension));
++numValues; ++numValues;
} }
if (random().nextBoolean() && values[i] != null) { if (random().nextBoolean() && values[i] != null) {
@ -943,7 +963,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (IndexReader reader = DirectoryReader.open(iw)) { try (IndexReader reader = DirectoryReader.open(iw)) {
int valueCount = 0, totalSize = 0; int valueCount = 0, totalSize = 0;
for (LeafReaderContext ctx : reader.leaves()) { for (LeafReaderContext ctx : reader.leaves()) {
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName); ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
if (vectorValues == null) { if (vectorValues == null) {
continue; continue;
} }
@ -951,7 +971,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
StoredFields storedFields = ctx.reader().storedFields(); StoredFields storedFields = ctx.reader().storedFields();
int docId; int docId;
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
BytesRef v = vectorValues.binaryValue(); BytesRef v = vectorValues.vectorValue();
assertEquals(dimension, v.length); assertEquals(dimension, v.length);
String idString = storedFields.document(docId).getField("id").stringValue(); String idString = storedFields.document(docId).getField("id").stringValue();
int id = Integer.parseInt(idString); int id = Integer.parseInt(idString);
@ -1141,7 +1161,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
throws IOException { throws IOException {
Document doc = new Document(); Document doc = new Document();
if (vector != null) { if (vector != null) {
doc.add(new KnnVectorField(field, vector, similarityFunction)); doc.add(new KnnByteVectorField(field, vector, similarityFunction));
} }
doc.add(new NumericDocValuesField("sortkey", sortKey)); doc.add(new NumericDocValuesField("sortkey", sortKey));
String idString = Integer.toString(id); String idString = Integer.toString(id);
@ -1183,13 +1203,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
return v; return v;
} }
private BytesRef randomVector8(int dim) { private byte[] randomVector8(int dim) {
float[] v = randomVector(dim); float[] v = randomVector(dim);
byte[] b = new byte[dim]; byte[] b = new byte[dim];
for (int i = 0; i < dim; i++) { for (int i = 0; i < dim; i++) {
b[i] = (byte) (v[i] * 127); b[i] = (byte) (v[i] * 127);
} }
return new BytesRef(b); return b;
} }
public void testCheckIndexIncludesVectors() throws Exception { public void testCheckIndexIncludesVectors() throws Exception {
@ -1297,9 +1317,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
if (random().nextInt(4) == 3) { if (random().nextInt(4) == 3) {
switch (vectorEncoding) { switch (vectorEncoding) {
case BYTE -> { case BYTE -> {
BytesRef b = randomVector8(dim); byte[] b = randomVector8(dim);
fieldValuesCheckSum += b.bytes[b.offset]; fieldValuesCheckSum += b[0];
doc.add(new KnnVectorField("knn_vector", b, similarityFunction)); doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
} }
case FLOAT32 -> { case FLOAT32 -> {
float[] v = randomVector(dim); float[] v = randomVector(dim);
@ -1321,15 +1341,33 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
double checksum = 0; double checksum = 0;
int docCount = 0; int docCount = 0;
long sumDocIds = 0; long sumDocIds = 0;
for (LeafReaderContext ctx : r.leaves()) { switch (vectorEncoding) {
VectorValues vectors = ctx.reader().getVectorValues("knn_vector"); case BYTE -> {
if (vectors != null) { for (LeafReaderContext ctx : r.leaves()) {
StoredFields storedFields = ctx.reader().storedFields(); ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues("knn_vector");
docCount += vectors.size(); if (byteVectorValues != null) {
while (vectors.nextDoc() != NO_MORE_DOCS) { docCount += byteVectorValues.size();
checksum += vectors.vectorValue()[0]; StoredFields storedFields = ctx.reader().storedFields();
Document doc = storedFields.document(vectors.docID(), Set.of("id")); while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
sumDocIds += Integer.parseInt(doc.get("id")); checksum += byteVectorValues.vectorValue().bytes[0];
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
sumDocIds += Integer.parseInt(doc.get("id"));
}
}
}
}
case FLOAT32 -> {
for (LeafReaderContext ctx : r.leaves()) {
VectorValues vectorValues = ctx.reader().getVectorValues("knn_vector");
if (vectorValues != null) {
docCount += vectorValues.size();
StoredFields storedFields = ctx.reader().storedFields();
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
checksum += vectorValues.vectorValue()[0];
Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
sumDocIds += Integer.parseInt(doc.get("id"));
}
}
} }
} }
} }

View File

@ -24,6 +24,7 @@ import org.apache.lucene.codecs.NormsProducer;
import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.StoredFieldsReader;
import org.apache.lucene.codecs.TermVectorsReader; import org.apache.lucene.codecs.TermVectorsReader;
import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DocValuesType; import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FieldInfo;
@ -222,6 +223,11 @@ class MergeReaderWrapper extends LeafReader {
return in.getVectorValues(fieldName); return in.getVectorValues(fieldName);
} }
@Override
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
return in.getByteVectorValues(fieldName);
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {

View File

@ -24,6 +24,7 @@ import java.io.IOException;
import java.util.List; import java.util.List;
import java.util.Random; import java.util.Random;
import org.apache.lucene.index.BinaryDocValues; import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafMetaData; import org.apache.lucene.index.LeafMetaData;
@ -229,6 +230,11 @@ public class QueryUtils {
return null; return null;
} }
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
return null;
}
@Override @Override
public TopDocs searchNearestVectors( public TopDocs searchNearestVectors(
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) { String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {