mirror of https://github.com/apache/lucene.git
Create new KnnByteVectorField and KnnVectorsReader#getByteVectorValues(String) (#12064)
This commit is contained in:
parent
e14327288e
commit
cc29102a24
|
@ -48,7 +48,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
private final Lucene90NeighborArray scratch;
|
private final Lucene90NeighborArray scratch;
|
||||||
|
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues<float[]> vectorValues;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final Lucene90BoundsChecker bound;
|
private final Lucene90BoundsChecker bound;
|
||||||
final Lucene90OnHeapHnswGraph hnsw;
|
final Lucene90OnHeapHnswGraph hnsw;
|
||||||
|
@ -57,7 +57,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
|
|
||||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||||
// colliding
|
// colliding
|
||||||
private final RandomAccessVectorValues buildVectors;
|
private final RandomAccessVectorValues<float[]> buildVectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||||
|
@ -72,7 +72,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
* to ensure repeatable construction.
|
* to ensure repeatable construction.
|
||||||
*/
|
*/
|
||||||
public Lucene90HnswGraphBuilder(
|
public Lucene90HnswGraphBuilder(
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<float[]> vectors,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
int maxConn,
|
int maxConn,
|
||||||
int beamWidth,
|
int beamWidth,
|
||||||
|
@ -103,7 +103,8 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||||
* accessor for the vectors
|
* accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
|
||||||
|
throws IOException {
|
||||||
if (vectors == vectorValues) {
|
if (vectors == vectorValues) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||||
|
@ -229,7 +230,7 @@ public final class Lucene90HnswGraphBuilder {
|
||||||
float[] candidate,
|
float[] candidate,
|
||||||
float score,
|
float score,
|
||||||
Lucene90NeighborArray neighbors,
|
Lucene90NeighborArray neighbors,
|
||||||
RandomAccessVectorValues vectorValues)
|
RandomAccessVectorValues<float[]> vectorValues)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
bound.set(score);
|
bound.set(score);
|
||||||
for (int i = 0; i < neighbors.size(); i++) {
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Map;
|
||||||
import java.util.SplittableRandom;
|
import java.util.SplittableRandom;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -232,6 +233,11 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return getOffHeapVectorValues(fieldEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
@ -352,7 +358,8 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
static class OffHeapVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
final int dimension;
|
final int dimension;
|
||||||
final int[] ordToDoc;
|
final int[] ordToDoc;
|
||||||
|
@ -433,7 +440,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() {
|
public RandomAccessVectorValues<float[]> copy() {
|
||||||
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
|
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -443,17 +450,6 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
||||||
dataIn.readFloats(value, 0, value.length);
|
dataIn.readFloats(value, 0, value.length);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
readValue(targetOrd);
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readValue(int targetOrd) throws IOException {
|
|
||||||
dataIn.seek((long) targetOrd * byteSize);
|
|
||||||
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the nearest-neighbors graph from the index input */
|
/** Read the nearest-neighbors graph from the index input */
|
||||||
|
|
|
@ -74,7 +74,7 @@ public final class Lucene90OnHeapHnswGraph extends HnswGraph {
|
||||||
float[] query,
|
float[] query,
|
||||||
int topK,
|
int topK,
|
||||||
int numSeed,
|
int numSeed,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<float[]> vectors,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
HnswGraph graphValues,
|
HnswGraph graphValues,
|
||||||
Bits acceptOrds,
|
Bits acceptOrds,
|
||||||
|
|
|
@ -27,6 +27,7 @@ import java.util.Map;
|
||||||
import java.util.function.IntUnaryOperator;
|
import java.util.function.IntUnaryOperator;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -224,6 +225,11 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
return getOffHeapVectorValues(fieldEntry);
|
return getOffHeapVectorValues(fieldEntry);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
@ -398,7 +404,8 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
static class OffHeapVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
private final int dimension;
|
private final int dimension;
|
||||||
private final int size;
|
private final int size;
|
||||||
|
@ -486,7 +493,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() {
|
public RandomAccessVectorValues<float[]> copy() {
|
||||||
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
|
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -496,17 +503,6 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
||||||
dataIn.readFloats(value, 0, value.length);
|
dataIn.readFloats(value, 0, value.length);
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
readValue(targetOrd);
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readValue(int targetOrd) throws IOException {
|
|
||||||
dataIn.seek((long) targetOrd * byteSize);
|
|
||||||
dataIn.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Read the nearest-neighbors graph from the index input */
|
/** Read the nearest-neighbors graph from the index input */
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -219,6 +220,11 @@ public final class Lucene92HnswVectorsReader extends KnnVectorsReader {
|
||||||
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
|
@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
|
|
||||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
abstract class OffHeapVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
protected final int dimension;
|
protected final int dimension;
|
||||||
protected final int size;
|
protected final int size;
|
||||||
|
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
readValue(targetOrd);
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readValue(int targetOrd) throws IOException {
|
|
||||||
slice.seek((long) targetOrd * byteSize);
|
|
||||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract int ordToDoc(int ord);
|
public abstract int ordToDoc(int ord);
|
||||||
|
|
||||||
static OffHeapVectorValues load(
|
static OffHeapVectorValues load(
|
||||||
|
@ -137,7 +127,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -210,7 +200,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -282,7 +272,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,11 +281,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int ordToDoc(int ord) {
|
public int ordToDoc(int ord) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -233,12 +234,31 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
|
throw new IllegalArgumentException(
|
||||||
return new ExpandingVectorValues(values);
|
"field=\""
|
||||||
} else {
|
+ field
|
||||||
return values;
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
}
|
}
|
||||||
|
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
|
}
|
||||||
|
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -292,7 +312,7 @@ public final class Lucene94HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
// bound k by total number of vectors to prevent oversizing data structures
|
// bound k by total number of vectors to prevent oversizing data structures
|
||||||
k = Math.min(k, fieldEntry.size());
|
k = Math.min(k, fieldEntry.size());
|
||||||
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
|
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||||
|
|
||||||
NeighborQueue results =
|
NeighborQueue results =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.backward_codecs.lucene94;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.RandomAccessInput;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
|
|
||||||
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
|
abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
|
implements RandomAccessVectorValues<BytesRef> {
|
||||||
|
|
||||||
|
protected final int dimension;
|
||||||
|
protected final int size;
|
||||||
|
protected final IndexInput slice;
|
||||||
|
protected final BytesRef binaryValue;
|
||||||
|
protected final ByteBuffer byteBuffer;
|
||||||
|
protected final int byteSize;
|
||||||
|
|
||||||
|
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||||
|
this.dimension = dimension;
|
||||||
|
this.size = size;
|
||||||
|
this.slice = slice;
|
||||||
|
this.byteSize = byteSize;
|
||||||
|
byteBuffer = ByteBuffer.allocate(byteSize);
|
||||||
|
binaryValue = new BytesRef(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||||
|
readValue(targetOrd);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readValue(int targetOrd) throws IOException {
|
||||||
|
slice.seek((long) targetOrd * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract int ordToDoc(int ord);
|
||||||
|
|
||||||
|
static OffHeapByteVectorValues load(
|
||||||
|
Lucene94HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||||
|
if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||||
|
}
|
||||||
|
IndexInput bytesSlice =
|
||||||
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
int byteSize = fieldEntry.dimension;
|
||||||
|
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||||
|
return new DenseOffHeapVectorValues(
|
||||||
|
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||||
|
} else {
|
||||||
|
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||||
|
|
||||||
|
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
|
||||||
|
private int doc = -1;
|
||||||
|
|
||||||
|
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||||
|
super(dimension, size, slice, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
slice.seek((long) doc * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return advance(doc + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
assert docID() < target;
|
||||||
|
if (target >= size) {
|
||||||
|
return doc = NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
return doc = target;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return ord;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return acceptDocs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
private final DirectMonotonicReader ordToDoc;
|
||||||
|
private final IndexedDISI disi;
|
||||||
|
// dataIn was used to init a new IndexedDISI for #randomAccess()
|
||||||
|
private final IndexInput dataIn;
|
||||||
|
private final Lucene94HnswVectorsReader.FieldEntry fieldEntry;
|
||||||
|
|
||||||
|
public SparseOffHeapVectorValues(
|
||||||
|
Lucene94HnswVectorsReader.FieldEntry fieldEntry,
|
||||||
|
IndexInput dataIn,
|
||||||
|
IndexInput slice,
|
||||||
|
int byteSize)
|
||||||
|
throws IOException {
|
||||||
|
|
||||||
|
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||||
|
this.fieldEntry = fieldEntry;
|
||||||
|
final RandomAccessInput addressesData =
|
||||||
|
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||||
|
this.dataIn = dataIn;
|
||||||
|
this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
|
||||||
|
this.disi =
|
||||||
|
new IndexedDISI(
|
||||||
|
dataIn,
|
||||||
|
fieldEntry.docsWithFieldOffset,
|
||||||
|
fieldEntry.docsWithFieldLength,
|
||||||
|
fieldEntry.jumpTableEntryCount,
|
||||||
|
fieldEntry.denseRankPower,
|
||||||
|
fieldEntry.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
slice.seek((long) (disi.index()) * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return disi.docID();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return disi.nextDoc();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
assert docID() < target;
|
||||||
|
return disi.advance(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return (int) ordToDoc.get(ord);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
if (acceptDocs == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return new Bits() {
|
||||||
|
@Override
|
||||||
|
public boolean get(int index) {
|
||||||
|
return acceptDocs.get(ordToDoc(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int length() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
|
||||||
|
public EmptyOffHeapVectorValues(int dimension) {
|
||||||
|
super(dimension, 0, null, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int doc = -1;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return super.dimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return advance(doc + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
return doc = NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -29,7 +29,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
|
|
||||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
abstract class OffHeapVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
protected final int dimension;
|
protected final int dimension;
|
||||||
protected final int size;
|
protected final int size;
|
||||||
|
@ -66,17 +67,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
readValue(targetOrd);
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readValue(int targetOrd) throws IOException {
|
|
||||||
slice.seek((long) targetOrd * byteSize);
|
|
||||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract int ordToDoc(int ord);
|
public abstract int ordToDoc(int ord);
|
||||||
|
|
||||||
static OffHeapVectorValues load(
|
static OffHeapVectorValues load(
|
||||||
|
@ -143,7 +133,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +209,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,7 +281,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,11 +290,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int ordToDoc(int ord) {
|
public int ordToDoc(int ord) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
||||||
|
|
||||||
private void writeGraph(
|
private void writeGraph(
|
||||||
IndexOutput graphData,
|
IndexOutput graphData,
|
||||||
RandomAccessVectorValues vectorValues,
|
RandomAccessVectorValues<float[]> vectorValues,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
long graphDataOffset,
|
long graphDataOffset,
|
||||||
long[] offsets,
|
long[] offsets,
|
||||||
|
|
|
@ -53,7 +53,7 @@ public final class Lucene91HnswGraphBuilder {
|
||||||
private final Lucene91NeighborArray scratch;
|
private final Lucene91NeighborArray scratch;
|
||||||
|
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final RandomAccessVectorValues vectorValues;
|
private final RandomAccessVectorValues<float[]> vectorValues;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final Lucene91BoundsChecker bound;
|
private final Lucene91BoundsChecker bound;
|
||||||
private final HnswGraphSearcher<float[]> graphSearcher;
|
private final HnswGraphSearcher<float[]> graphSearcher;
|
||||||
|
@ -64,7 +64,7 @@ public final class Lucene91HnswGraphBuilder {
|
||||||
|
|
||||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||||
// colliding
|
// colliding
|
||||||
private RandomAccessVectorValues buildVectors;
|
private RandomAccessVectorValues<float[]> buildVectors;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
* Reads all the vectors from a VectorValues, builds a graph connecting them by their dense
|
||||||
|
@ -79,7 +79,7 @@ public final class Lucene91HnswGraphBuilder {
|
||||||
* to ensure repeatable construction.
|
* to ensure repeatable construction.
|
||||||
*/
|
*/
|
||||||
public Lucene91HnswGraphBuilder(
|
public Lucene91HnswGraphBuilder(
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<float[]> vectors,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
int maxConn,
|
int maxConn,
|
||||||
int beamWidth,
|
int beamWidth,
|
||||||
|
@ -119,7 +119,8 @@ public final class Lucene91HnswGraphBuilder {
|
||||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet
|
||||||
* accessor for the vectors
|
* accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues<float[]> vectors)
|
||||||
|
throws IOException {
|
||||||
if (vectors == vectorValues) {
|
if (vectors == vectorValues) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||||
|
@ -250,7 +251,7 @@ public final class Lucene91HnswGraphBuilder {
|
||||||
float[] candidate,
|
float[] candidate,
|
||||||
float score,
|
float score,
|
||||||
Lucene91NeighborArray neighbors,
|
Lucene91NeighborArray neighbors,
|
||||||
RandomAccessVectorValues vectorValues)
|
RandomAccessVectorValues<float[]> vectorValues)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
bound.set(score);
|
bound.set(score);
|
||||||
for (int i = 0; i < neighbors.size(); i++) {
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
|
|
|
@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private Lucene91OnHeapHnswGraph writeGraph(
|
private Lucene91OnHeapHnswGraph writeGraph(
|
||||||
RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
|
RandomAccessVectorValues<float[]> vectorValues, VectorSimilarityFunction similarityFunction)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
// build graph
|
// build graph
|
||||||
|
|
|
@ -268,13 +268,13 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
private OnHeapHnswGraph writeGraph(
|
private OnHeapHnswGraph writeGraph(
|
||||||
RandomAccessVectorValues vectorValues,
|
RandomAccessVectorValues<float[]> vectorValues,
|
||||||
VectorEncoding vectorEncoding,
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction)
|
VectorSimilarityFunction similarityFunction)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectorValues,
|
vectorValues,
|
||||||
vectorEncoding,
|
vectorEncoding,
|
||||||
|
|
|
@ -30,12 +30,14 @@ import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DocsWithFieldSet;
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
@ -379,8 +381,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
@Override
|
@Override
|
||||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
|
||||||
|
|
||||||
IndexOutput tempVectorData =
|
IndexOutput tempVectorData =
|
||||||
segmentWriteState.directory.createTempOutput(
|
segmentWriteState.directory.createTempOutput(
|
||||||
vectorData.getName(), "temp", segmentWriteState.context);
|
vectorData.getName(), "temp", segmentWriteState.context);
|
||||||
|
@ -389,7 +389,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
try {
|
try {
|
||||||
// write the vector data to a temporary file
|
// write the vector data to a temporary file
|
||||||
DocsWithFieldSet docsWithField =
|
DocsWithFieldSet docsWithField =
|
||||||
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeByteVectorData(
|
||||||
|
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||||
|
case FLOAT32 -> writeVectorData(
|
||||||
|
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||||
|
};
|
||||||
CodecUtil.writeFooter(tempVectorData);
|
CodecUtil.writeFooter(tempVectorData);
|
||||||
IOUtils.close(tempVectorData);
|
IOUtils.close(tempVectorData);
|
||||||
|
|
||||||
|
@ -405,23 +410,49 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
// we use Lucene94HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||||
// doesn't need to know docIds
|
// doesn't need to know docIds
|
||||||
// TODO: separate random access vector values from DocIdSetIterator?
|
// TODO: separate random access vector values from DocIdSetIterator?
|
||||||
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||||
OffHeapVectorValues offHeapVectors =
|
|
||||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
|
||||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
|
||||||
OnHeapHnswGraph graph = null;
|
OnHeapHnswGraph graph = null;
|
||||||
if (offHeapVectors.size() != 0) {
|
if (docsWithField.cardinality() != 0) {
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
graph =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> {
|
||||||
|
OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||||
|
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
vectorDataInput,
|
||||||
|
byteSize);
|
||||||
|
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
offHeapVectors,
|
vectorValues,
|
||||||
fieldInfo.getVectorEncoding(),
|
fieldInfo.getVectorEncoding(),
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
M,
|
M,
|
||||||
beamWidth,
|
beamWidth,
|
||||||
HnswGraphBuilder.randSeed);
|
HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
graph = hnswGraphBuilder.build(offHeapVectors.copy());
|
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||||
|
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
vectorDataInput,
|
||||||
|
byteSize);
|
||||||
|
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectorValues,
|
||||||
|
fieldInfo.getVectorEncoding(),
|
||||||
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
|
M,
|
||||||
|
beamWidth,
|
||||||
|
HnswGraphBuilder.randSeed);
|
||||||
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
|
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||||
|
}
|
||||||
|
};
|
||||||
writeGraph(graph);
|
writeGraph(graph);
|
||||||
}
|
}
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
|
@ -554,16 +585,37 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||||
|
* vectors.
|
||||||
|
*/
|
||||||
|
private static DocsWithFieldSet writeByteVectorData(
|
||||||
|
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||||
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
|
for (int docV = byteVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = byteVectorValues.nextDoc()) {
|
||||||
|
// write vector
|
||||||
|
BytesRef binaryValue = byteVectorValues.binaryValue();
|
||||||
|
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||||
|
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
|
docsWithField.add(docV);
|
||||||
|
}
|
||||||
|
return docsWithField;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||||
*/
|
*/
|
||||||
private static DocsWithFieldSet writeVectorData(
|
private static DocsWithFieldSet writeVectorData(
|
||||||
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
IndexOutput output, VectorValues floatVectorValues) throws IOException {
|
||||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
for (int docV = floatVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = floatVectorValues.nextDoc()) {
|
||||||
// write vector
|
// write vector
|
||||||
BytesRef binaryValue = vectors.binaryValue();
|
BytesRef binaryValue = floatVectorValues.binaryValue();
|
||||||
assert binaryValue.length == vectors.dimension() * scalarSize;
|
assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
|
||||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
docsWithField.add(docV);
|
docsWithField.add(docV);
|
||||||
}
|
}
|
||||||
|
@ -580,7 +632,6 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
private final int dim;
|
private final int dim;
|
||||||
private final DocsWithFieldSet docsWithField;
|
private final DocsWithFieldSet docsWithField;
|
||||||
private final List<T> vectors;
|
private final List<T> vectors;
|
||||||
private final RAVectorValues<T> raVectorValues;
|
|
||||||
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||||
|
|
||||||
private int lastDocID = -1;
|
private int lastDocID = -1;
|
||||||
|
@ -593,8 +644,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
case BYTE -> new FieldWriter<BytesRef>(fieldInfo, M, beamWidth, infoStream) {
|
||||||
@Override
|
@Override
|
||||||
public BytesRef copyValue(BytesRef value) {
|
public BytesRef copyValue(BytesRef value) {
|
||||||
return new BytesRef(
|
return new BytesRef(ArrayUtil.copyOfSubArray(value.bytes, value.offset, dim));
|
||||||
ArrayUtil.copyOfSubArray(value.bytes, value.offset, value.offset + dim));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream) {
|
||||||
|
@ -613,9 +663,8 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
this.dim = fieldInfo.getVectorDimension();
|
this.dim = fieldInfo.getVectorDimension();
|
||||||
this.docsWithField = new DocsWithFieldSet();
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
vectors = new ArrayList<>();
|
vectors = new ArrayList<>();
|
||||||
raVectorValues = new RAVectorValues<>(vectors, dim);
|
RAVectorValues<T> raVectorValues = new RAVectorValues<>(vectors, dim);
|
||||||
hnswGraphBuilder =
|
hnswGraphBuilder =
|
||||||
(HnswGraphBuilder<T>)
|
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
raVectorValues,
|
raVectorValues,
|
||||||
fieldInfo.getVectorEncoding(),
|
fieldInfo.getVectorEncoding(),
|
||||||
|
@ -667,7 +716,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||||
private final List<T> vectors;
|
private final List<T> vectors;
|
||||||
private final int dim;
|
private final int dim;
|
||||||
|
|
||||||
|
@ -687,17 +736,12 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue(int targetOrd) throws IOException {
|
public T vectorValue(int targetOrd) throws IOException {
|
||||||
return (float[]) vectors.get(targetOrd);
|
return vectors.get(targetOrd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
public RAVectorValues<T> copy() throws IOException {
|
||||||
return (BytesRef) vectors.get(targetOrd);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import java.nio.charset.StandardCharsets;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
@ -143,6 +144,39 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
|
return new SimpleTextVectorValues(fieldEntry, bytesSlice, info.getVectorEncoding());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||||
|
if (info == null) {
|
||||||
|
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||||
|
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
int dimension = info.getVectorDimension();
|
||||||
|
if (dimension == 0) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"KNN vectors readers should not be called on fields that don't enable KNN vectors");
|
||||||
|
}
|
||||||
|
FieldEntry fieldEntry = fieldEntries.get(field);
|
||||||
|
if (fieldEntry == null) {
|
||||||
|
// mirror the handling in Lucene90VectorReader#getVectorValues
|
||||||
|
// needed to pass TestSimpleTextKnnVectorsFormat#testDeleteAllVectorDocs
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (dimension != fieldEntry.dimension) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Inconsistent vector dimension for field=\""
|
||||||
|
+ field
|
||||||
|
+ "\"; "
|
||||||
|
+ dimension
|
||||||
|
+ " != "
|
||||||
|
+ fieldEntry.dimension);
|
||||||
|
}
|
||||||
|
IndexInput bytesSlice =
|
||||||
|
dataIn.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
return new SimpleTextByteVectorValues(fieldEntry, bytesSlice);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
@ -187,7 +221,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, BytesRef target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
VectorValues values = getVectorValues(field);
|
ByteVectorValues values = getByteVectorValues(field);
|
||||||
if (target.length != values.dimension()) {
|
if (target.length != values.dimension()) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"vector query dimension: "
|
"vector query dimension: "
|
||||||
|
@ -213,7 +247,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
BytesRef vector = values.binaryValue();
|
BytesRef vector = values.vectorValue();
|
||||||
float score = vectorSimilarity.compare(vector, target);
|
float score = vectorSimilarity.compare(vector, target);
|
||||||
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||||
numVisited++;
|
numVisited++;
|
||||||
|
@ -301,7 +335,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class SimpleTextVectorValues extends VectorValues
|
private static class SimpleTextVectorValues extends VectorValues
|
||||||
implements RandomAccessVectorValues {
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||||
private final FieldEntry entry;
|
private final FieldEntry entry;
|
||||||
|
@ -356,7 +390,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() {
|
public RandomAccessVectorValues<float[]> copy() {
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -409,10 +443,99 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
||||||
public float[] vectorValue(int targetOrd) throws IOException {
|
public float[] vectorValue(int targetOrd) throws IOException {
|
||||||
return values[targetOrd];
|
return values[targetOrd];
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class SimpleTextByteVectorValues extends ByteVectorValues
|
||||||
|
implements RandomAccessVectorValues<BytesRef> {
|
||||||
|
|
||||||
|
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||||
|
private final FieldEntry entry;
|
||||||
|
private final IndexInput in;
|
||||||
|
private final BytesRef binaryValue;
|
||||||
|
private final byte[][] values;
|
||||||
|
|
||||||
|
int curOrd;
|
||||||
|
|
||||||
|
SimpleTextByteVectorValues(FieldEntry entry, IndexInput in) throws IOException {
|
||||||
|
this.entry = entry;
|
||||||
|
this.in = in;
|
||||||
|
values = new byte[entry.size()][entry.dimension];
|
||||||
|
binaryValue = new BytesRef(entry.dimension);
|
||||||
|
binaryValue.length = binaryValue.bytes.length;
|
||||||
|
curOrd = -1;
|
||||||
|
readAllVectors();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
public int dimension() {
|
||||||
throw new UnsupportedOperationException();
|
return entry.dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return entry.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() {
|
||||||
|
binaryValue.bytes = values[curOrd];
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() {
|
||||||
|
return this;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
if (curOrd == -1) {
|
||||||
|
return -1;
|
||||||
|
} else if (curOrd >= entry.size()) {
|
||||||
|
// when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID
|
||||||
|
// immediately afterward should also return NO_MORE_DOCS
|
||||||
|
// this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case
|
||||||
|
return NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry.ordToDoc[curOrd];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
if (++curOrd < entry.size()) {
|
||||||
|
return docID();
|
||||||
|
}
|
||||||
|
return NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
return slowAdvance(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readAllVectors() throws IOException {
|
||||||
|
for (byte[] value : values) {
|
||||||
|
readVector(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readVector(byte[] value) throws IOException {
|
||||||
|
SimpleTextUtil.readLine(in, scratch);
|
||||||
|
// skip leading "[" and strip trailing "]"
|
||||||
|
String s = new BytesRef(scratch.bytes(), 1, scratch.length() - 2).utf8ToString();
|
||||||
|
String[] floatStrings = s.split(",");
|
||||||
|
assert floatStrings.length == value.length
|
||||||
|
: " read " + s + " when expecting " + value.length + " floats";
|
||||||
|
for (int i = 0; i < floatStrings.length; i++) {
|
||||||
|
value[i] = (byte) Float.parseFloat(floatStrings[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||||
|
binaryValue.bytes = values[curOrd];
|
||||||
|
return binaryValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import java.nio.ByteBuffer;
|
||||||
import java.nio.ByteOrder;
|
import java.nio.ByteOrder;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DocsWithFieldSet;
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
|
@ -85,6 +86,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||||
: vectorValues;
|
: vectorValues;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(
|
public TopDocs search(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
@ -202,6 +208,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||||
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
return MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() {}
|
public void checkIntegrity() {}
|
||||||
};
|
};
|
||||||
|
@ -228,7 +239,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void addValue(int docID, Object value) {
|
public void addValue(int docID, float[] value) {
|
||||||
if (docID == lastDocID) {
|
if (docID == lastDocID) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"VectorValuesField \""
|
"VectorValuesField \""
|
||||||
|
@ -236,25 +247,11 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
||||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
}
|
}
|
||||||
assert docID > lastDocID;
|
assert docID > lastDocID;
|
||||||
float[] vectorValue =
|
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
|
||||||
case FLOAT32 -> (float[]) value;
|
|
||||||
case BYTE -> bytesToFloats((BytesRef) value);
|
|
||||||
};
|
|
||||||
docsWithField.add(docID);
|
docsWithField.add(docID);
|
||||||
vectors.add(copyValue(vectorValue));
|
vectors.add(copyValue(value));
|
||||||
lastDocID = docID;
|
lastDocID = docID;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float[] bytesToFloats(BytesRef b) {
|
|
||||||
// This is used only by SimpleTextKnnVectorsWriter
|
|
||||||
float[] floats = new float[dim];
|
|
||||||
for (int i = 0; i < dim; i++) {
|
|
||||||
floats[i] = b.bytes[i + b.offset];
|
|
||||||
}
|
|
||||||
return floats;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] copyValue(float[] vectorValue) {
|
public float[] copyValue(float[] vectorValue) {
|
||||||
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
return ArrayUtil.copyOfSubArray(vectorValue, 0, dim);
|
||||||
|
|
|
@ -34,7 +34,7 @@ public abstract class KnnFieldVectorsWriter<T> implements Accountable {
|
||||||
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
|
* Add new docID with its vector value to the given field for indexing. Doc IDs must be added in
|
||||||
* increasing order.
|
* increasing order.
|
||||||
*/
|
*/
|
||||||
public abstract void addValue(int docID, Object vectorValue) throws IOException;
|
public abstract void addValue(int docID, T vectorValue) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Used to copy values being indexed to internal storage.
|
* Used to copy values being indexed to internal storage.
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.lucene.codecs;
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
@ -98,6 +99,11 @@ public abstract class KnnVectorsFormat implements NamedSPILoader.NamedSPI {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(
|
public TopDocs search(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.codecs;
|
||||||
|
|
||||||
import java.io.Closeable;
|
import java.io.Closeable;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
|
@ -51,6 +52,13 @@ public abstract class KnnVectorsReader implements Closeable, Accountable {
|
||||||
*/
|
*/
|
||||||
public abstract VectorValues getVectorValues(String field) throws IOException;
|
public abstract VectorValues getVectorValues(String field) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
|
||||||
|
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||||
|
* never {@code null}.
|
||||||
|
*/
|
||||||
|
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||||
|
|
|
@ -21,10 +21,12 @@ import java.io.Closeable;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DocIDMerger;
|
import org.apache.lucene.index.DocIDMerger;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.util.Accountable;
|
import org.apache.lucene.util.Accountable;
|
||||||
|
@ -44,13 +46,29 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
|
|
||||||
/** Write field for merging */
|
/** Write field for merging */
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public <T> void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
KnnFieldVectorsWriter<T> writer = (KnnFieldVectorsWriter<T>) addField(fieldInfo);
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
VectorValues mergedValues = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
case BYTE:
|
||||||
for (int doc = mergedValues.nextDoc();
|
KnnFieldVectorsWriter<BytesRef> byteWriter =
|
||||||
|
(KnnFieldVectorsWriter<BytesRef>) addField(fieldInfo);
|
||||||
|
ByteVectorValues mergedBytes =
|
||||||
|
MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||||
|
for (int doc = mergedBytes.nextDoc();
|
||||||
doc != DocIdSetIterator.NO_MORE_DOCS;
|
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
doc = mergedValues.nextDoc()) {
|
doc = mergedBytes.nextDoc()) {
|
||||||
writer.addValue(doc, mergedValues.vectorValue());
|
byteWriter.addValue(doc, mergedBytes.vectorValue());
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case FLOAT32:
|
||||||
|
KnnFieldVectorsWriter<float[]> floatWriter =
|
||||||
|
(KnnFieldVectorsWriter<float[]>) addField(fieldInfo);
|
||||||
|
VectorValues mergedFloats = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
||||||
|
for (int doc = mergedFloats.nextDoc();
|
||||||
|
doc != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
doc = mergedFloats.nextDoc()) {
|
||||||
|
floatWriter.addValue(doc, mergedFloats.vectorValue());
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -104,20 +122,34 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
|
private static class ByteVectorValuesSub extends DocIDMerger.Sub {
|
||||||
protected static class MergedVectorValues extends VectorValues {
|
|
||||||
private final List<VectorValuesSub> subs;
|
|
||||||
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
|
||||||
private final int size;
|
|
||||||
|
|
||||||
private int docId;
|
final ByteVectorValues values;
|
||||||
private VectorValuesSub current;
|
|
||||||
|
ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) {
|
||||||
|
super(docMap);
|
||||||
|
this.values = values;
|
||||||
|
assert values.docID() == -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return values.nextDoc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** View over multiple VectorValues supporting iterator-style access via DocIdMerger. */
|
||||||
|
protected static final class MergedVectorValues {
|
||||||
|
private MergedVectorValues() {}
|
||||||
|
|
||||||
/** Returns a merged view over all the segment's {@link VectorValues}. */
|
/** Returns a merged view over all the segment's {@link VectorValues}. */
|
||||||
public static MergedVectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
public static VectorValues mergeVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||||
|
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as FLOAT32");
|
||||||
|
}
|
||||||
List<VectorValuesSub> subs = new ArrayList<>();
|
List<VectorValuesSub> subs = new ArrayList<>();
|
||||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||||
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||||
|
@ -128,10 +160,39 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return new MergedVectorValues(subs, mergeState);
|
return new MergedFloat32VectorValues(subs, mergeState);
|
||||||
}
|
}
|
||||||
|
|
||||||
private MergedVectorValues(List<VectorValuesSub> subs, MergeState mergeState)
|
/** Returns a merged view over all the segment's {@link ByteVectorValues}. */
|
||||||
|
public static ByteVectorValues mergeByteVectorValues(FieldInfo fieldInfo, MergeState mergeState)
|
||||||
|
throws IOException {
|
||||||
|
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||||
|
if (fieldInfo.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||||
|
throw new UnsupportedOperationException(
|
||||||
|
"Cannot merge vectors encoded as [" + fieldInfo.getVectorEncoding() + "] as BYTE");
|
||||||
|
}
|
||||||
|
List<ByteVectorValuesSub> subs = new ArrayList<>();
|
||||||
|
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||||
|
KnnVectorsReader knnVectorsReader = mergeState.knnVectorsReaders[i];
|
||||||
|
if (knnVectorsReader != null) {
|
||||||
|
ByteVectorValues values = knnVectorsReader.getByteVectorValues(fieldInfo.name);
|
||||||
|
if (values != null) {
|
||||||
|
subs.add(new ByteVectorValuesSub(mergeState.docMaps[i], values));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new MergedByteVectorValues(subs, mergeState);
|
||||||
|
}
|
||||||
|
|
||||||
|
static class MergedFloat32VectorValues extends VectorValues {
|
||||||
|
private final List<VectorValuesSub> subs;
|
||||||
|
private final DocIDMerger<VectorValuesSub> docIdMerger;
|
||||||
|
private final int size;
|
||||||
|
|
||||||
|
private int docId;
|
||||||
|
VectorValuesSub current;
|
||||||
|
|
||||||
|
private MergedFloat32VectorValues(List<VectorValuesSub> subs, MergeState mergeState)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.subs = subs;
|
this.subs = subs;
|
||||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||||
|
@ -184,4 +245,62 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
||||||
return subs.get(0).values.dimension();
|
return subs.get(0).values.dimension();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static class MergedByteVectorValues extends ByteVectorValues {
|
||||||
|
private final List<ByteVectorValuesSub> subs;
|
||||||
|
private final DocIDMerger<ByteVectorValuesSub> docIdMerger;
|
||||||
|
private final int size;
|
||||||
|
|
||||||
|
private int docId;
|
||||||
|
ByteVectorValuesSub current;
|
||||||
|
|
||||||
|
private MergedByteVectorValues(List<ByteVectorValuesSub> subs, MergeState mergeState)
|
||||||
|
throws IOException {
|
||||||
|
this.subs = subs;
|
||||||
|
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||||
|
int totalSize = 0;
|
||||||
|
for (ByteVectorValuesSub sub : subs) {
|
||||||
|
totalSize += sub.values.size();
|
||||||
|
}
|
||||||
|
size = totalSize;
|
||||||
|
docId = -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
return current.values.vectorValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return docId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
current = docIdMerger.next();
|
||||||
|
if (current == null) {
|
||||||
|
docId = NO_MORE_DOCS;
|
||||||
|
} else {
|
||||||
|
docId = current.mappedDocID;
|
||||||
|
}
|
||||||
|
return docId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return subs.get(0).values.dimension();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,47 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.apache.lucene.codecs.lucene95;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import org.apache.lucene.index.FilterVectorValues;
|
|
||||||
import org.apache.lucene.index.VectorValues;
|
|
||||||
import org.apache.lucene.util.BytesRef;
|
|
||||||
|
|
||||||
/** reads from byte-encoded data */
|
|
||||||
class ExpandingVectorValues extends FilterVectorValues {
|
|
||||||
|
|
||||||
private final float[] value;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param in the wrapped values
|
|
||||||
*/
|
|
||||||
protected ExpandingVectorValues(VectorValues in) {
|
|
||||||
super(in);
|
|
||||||
value = new float[in.dimension()];
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public float[] vectorValue() throws IOException {
|
|
||||||
BytesRef binaryValue = binaryValue();
|
|
||||||
byte[] bytes = binaryValue.bytes;
|
|
||||||
for (int i = 0, j = binaryValue.offset; i < value.length; i++, j++) {
|
|
||||||
value[i] = bytes[j];
|
|
||||||
}
|
|
||||||
return value;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -25,6 +25,7 @@ import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -238,12 +239,31 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
VectorValues values = OffHeapVectorValues.load(fieldEntry, vectorData);
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
if (fieldEntry.vectorEncoding == VectorEncoding.BYTE) {
|
throw new IllegalArgumentException(
|
||||||
return new ExpandingVectorValues(values);
|
"field=\""
|
||||||
} else {
|
+ field
|
||||||
return values;
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
}
|
}
|
||||||
|
return OffHeapVectorValues.load(fieldEntry, vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
|
}
|
||||||
|
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -303,7 +323,7 @@ public final class Lucene95HnswVectorsReader extends KnnVectorsReader {
|
||||||
|
|
||||||
// bound k by total number of vectors to prevent oversizing data structures
|
// bound k by total number of vectors to prevent oversizing data structures
|
||||||
k = Math.min(k, fieldEntry.size());
|
k = Math.min(k, fieldEntry.size());
|
||||||
OffHeapVectorValues vectorValues = OffHeapVectorValues.load(fieldEntry, vectorData);
|
OffHeapByteVectorValues vectorValues = OffHeapByteVectorValues.load(fieldEntry, vectorData);
|
||||||
|
|
||||||
NeighborQueue results =
|
NeighborQueue results =
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
|
|
|
@ -391,17 +391,21 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
@Override
|
@Override
|
||||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
VectorValues vectors = MergedVectorValues.mergeVectorValues(fieldInfo, mergeState);
|
|
||||||
|
|
||||||
IndexOutput tempVectorData =
|
IndexOutput tempVectorData =
|
||||||
segmentWriteState.directory.createTempOutput(
|
segmentWriteState.directory.createTempOutput(
|
||||||
vectorData.getName(), "temp", segmentWriteState.context);
|
vectorData.getName(), "temp", segmentWriteState.context);
|
||||||
IndexInput vectorDataInput = null;
|
IndexInput vectorDataInput = null;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
|
// write the vector data to a temporary file
|
||||||
// write the vector data to a temporary file
|
// write the vector data to a temporary file
|
||||||
DocsWithFieldSet docsWithField =
|
DocsWithFieldSet docsWithField =
|
||||||
writeVectorData(tempVectorData, vectors, fieldInfo.getVectorEncoding().byteSize);
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeByteVectorData(
|
||||||
|
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||||
|
case FLOAT32 -> writeVectorData(
|
||||||
|
tempVectorData, MergedVectorValues.mergeVectorValues(fieldInfo, mergeState));
|
||||||
|
};
|
||||||
CodecUtil.writeFooter(tempVectorData);
|
CodecUtil.writeFooter(tempVectorData);
|
||||||
IOUtils.close(tempVectorData);
|
IOUtils.close(tempVectorData);
|
||||||
|
|
||||||
|
@ -417,24 +421,50 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
// we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
// we use Lucene95HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||||
// doesn't need to know docIds
|
// doesn't need to know docIds
|
||||||
// TODO: separate random access vector values from DocIdSetIterator?
|
// TODO: separate random access vector values from DocIdSetIterator?
|
||||||
int byteSize = vectors.dimension() * fieldInfo.getVectorEncoding().byteSize;
|
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||||
OffHeapVectorValues offHeapVectors =
|
|
||||||
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
|
||||||
vectors.dimension(), docsWithField.cardinality(), vectorDataInput, byteSize);
|
|
||||||
OnHeapHnswGraph graph = null;
|
OnHeapHnswGraph graph = null;
|
||||||
int[][] vectorIndexNodeOffsets = null;
|
int[][] vectorIndexNodeOffsets = null;
|
||||||
if (offHeapVectors.size() != 0) {
|
if (docsWithField.cardinality() != 0) {
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphBuilder<?> hnswGraphBuilder =
|
graph =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> {
|
||||||
|
OffHeapByteVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||||
|
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
vectorDataInput,
|
||||||
|
byteSize);
|
||||||
|
HnswGraphBuilder<BytesRef> hnswGraphBuilder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
offHeapVectors,
|
vectorValues,
|
||||||
fieldInfo.getVectorEncoding(),
|
fieldInfo.getVectorEncoding(),
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
M,
|
M,
|
||||||
beamWidth,
|
beamWidth,
|
||||||
HnswGraphBuilder.randSeed);
|
HnswGraphBuilder.randSeed);
|
||||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
graph = hnswGraphBuilder.build(offHeapVectors.copy());
|
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
OffHeapVectorValues.DenseOffHeapVectorValues vectorValues =
|
||||||
|
new OffHeapVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
vectorDataInput,
|
||||||
|
byteSize);
|
||||||
|
HnswGraphBuilder<float[]> hnswGraphBuilder =
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectorValues,
|
||||||
|
fieldInfo.getVectorEncoding(),
|
||||||
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
|
M,
|
||||||
|
beamWidth,
|
||||||
|
HnswGraphBuilder.randSeed);
|
||||||
|
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||||
|
yield hnswGraphBuilder.build(vectorValues.copy());
|
||||||
|
}
|
||||||
|
};
|
||||||
vectorIndexNodeOffsets = writeGraph(graph);
|
vectorIndexNodeOffsets = writeGraph(graph);
|
||||||
}
|
}
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
|
@ -605,16 +635,37 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||||
|
* vectors.
|
||||||
|
*/
|
||||||
|
private static DocsWithFieldSet writeByteVectorData(
|
||||||
|
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||||
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
|
for (int docV = byteVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = byteVectorValues.nextDoc()) {
|
||||||
|
// write vector
|
||||||
|
BytesRef binaryValue = byteVectorValues.binaryValue();
|
||||||
|
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||||
|
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
|
docsWithField.add(docV);
|
||||||
|
}
|
||||||
|
return docsWithField;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||||
*/
|
*/
|
||||||
private static DocsWithFieldSet writeVectorData(
|
private static DocsWithFieldSet writeVectorData(
|
||||||
IndexOutput output, VectorValues vectors, int scalarSize) throws IOException {
|
IndexOutput output, VectorValues floatVectorValues) throws IOException {
|
||||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) {
|
for (int docV = floatVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = floatVectorValues.nextDoc()) {
|
||||||
// write vector
|
// write vector
|
||||||
BytesRef binaryValue = vectors.binaryValue();
|
BytesRef binaryValue = floatVectorValues.binaryValue();
|
||||||
assert binaryValue.length == vectors.dimension() * scalarSize;
|
assert binaryValue.length == floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize;
|
||||||
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
output.writeBytes(binaryValue.bytes, binaryValue.offset, binaryValue.length);
|
||||||
docsWithField.add(docV);
|
docsWithField.add(docV);
|
||||||
}
|
}
|
||||||
|
@ -631,7 +682,6 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
private final int dim;
|
private final int dim;
|
||||||
private final DocsWithFieldSet docsWithField;
|
private final DocsWithFieldSet docsWithField;
|
||||||
private final List<T> vectors;
|
private final List<T> vectors;
|
||||||
private final RAVectorValues<T> raVectorValues;
|
|
||||||
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
private final HnswGraphBuilder<T> hnswGraphBuilder;
|
||||||
|
|
||||||
private int lastDocID = -1;
|
private int lastDocID = -1;
|
||||||
|
@ -657,18 +707,15 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.fieldInfo = fieldInfo;
|
this.fieldInfo = fieldInfo;
|
||||||
this.dim = fieldInfo.getVectorDimension();
|
this.dim = fieldInfo.getVectorDimension();
|
||||||
this.docsWithField = new DocsWithFieldSet();
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
vectors = new ArrayList<>();
|
vectors = new ArrayList<>();
|
||||||
raVectorValues = new RAVectorValues<>(vectors, dim);
|
|
||||||
hnswGraphBuilder =
|
hnswGraphBuilder =
|
||||||
(HnswGraphBuilder<T>)
|
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
raVectorValues,
|
new RAVectorValues<>(vectors, dim),
|
||||||
fieldInfo.getVectorEncoding(),
|
fieldInfo.getVectorEncoding(),
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
M,
|
M,
|
||||||
|
@ -678,15 +725,13 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@SuppressWarnings("unchecked")
|
public void addValue(int docID, T vectorValue) throws IOException {
|
||||||
public void addValue(int docID, Object value) throws IOException {
|
|
||||||
if (docID == lastDocID) {
|
if (docID == lastDocID) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"VectorValuesField \""
|
"VectorValuesField \""
|
||||||
+ fieldInfo.name
|
+ fieldInfo.name
|
||||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
}
|
}
|
||||||
T vectorValue = (T) value;
|
|
||||||
assert docID > lastDocID;
|
assert docID > lastDocID;
|
||||||
docsWithField.add(docID);
|
docsWithField.add(docID);
|
||||||
vectors.add(copyValue(vectorValue));
|
vectors.add(copyValue(vectorValue));
|
||||||
|
@ -719,7 +764,7 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class RAVectorValues<T> implements RandomAccessVectorValues {
|
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||||
private final List<T> vectors;
|
private final List<T> vectors;
|
||||||
private final int dim;
|
private final int dim;
|
||||||
|
|
||||||
|
@ -739,17 +784,12 @@ public final class Lucene95HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue(int targetOrd) throws IOException {
|
public T vectorValue(int targetOrd) throws IOException {
|
||||||
return (float[]) vectors.get(targetOrd);
|
return vectors.get(targetOrd);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
public RandomAccessVectorValues<T> copy() throws IOException {
|
||||||
return (BytesRef) vectors.get(targetOrd);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
|
||||||
return this;
|
return this;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,283 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene95;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.RandomAccessInput;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
|
|
||||||
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
|
abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
|
implements RandomAccessVectorValues<BytesRef> {
|
||||||
|
|
||||||
|
protected final int dimension;
|
||||||
|
protected final int size;
|
||||||
|
protected final IndexInput slice;
|
||||||
|
protected final BytesRef binaryValue;
|
||||||
|
protected final ByteBuffer byteBuffer;
|
||||||
|
protected final int byteSize;
|
||||||
|
|
||||||
|
/**
 * Stores the layout parameters and allocates one reusable buffer of {@code byteSize} bytes;
 * {@code binaryValue} is a view over that same buffer.
 */
OffHeapByteVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
  this.dimension = dimension;
  this.size = size;
  this.slice = slice;
  this.byteSize = byteSize;
  this.byteBuffer = ByteBuffer.allocate(byteSize);
  this.binaryValue =
      new BytesRef(this.byteBuffer.array(), this.byteBuffer.arrayOffset(), byteSize);
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||||
|
readValue(targetOrd);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readValue(int targetOrd) throws IOException {
|
||||||
|
slice.seek((long) targetOrd * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
public abstract int ordToDoc(int ord);
|
||||||
|
|
||||||
|
static OffHeapByteVectorValues load(
|
||||||
|
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||||
|
if (fieldEntry.docsWithFieldOffset == -2 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||||
|
}
|
||||||
|
IndexInput bytesSlice =
|
||||||
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
|
int byteSize = fieldEntry.dimension;
|
||||||
|
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||||
|
return new DenseOffHeapVectorValues(
|
||||||
|
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||||
|
} else {
|
||||||
|
return new SparseOffHeapVectorValues(fieldEntry, vectorData, bytesSlice, byteSize);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||||
|
|
||||||
|
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
|
||||||
|
private int doc = -1;
|
||||||
|
|
||||||
|
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice, int byteSize) {
|
||||||
|
super(dimension, size, slice, byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
slice.seek((long) doc * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return advance(doc + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
assert docID() < target;
|
||||||
|
if (target >= size) {
|
||||||
|
return doc = NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
return doc = target;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return ord;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return acceptDocs;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class SparseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
private final DirectMonotonicReader ordToDoc;
|
||||||
|
private final IndexedDISI disi;
|
||||||
|
// dataIn was used to init a new IndexedDIS for #randomAccess()
|
||||||
|
private final IndexInput dataIn;
|
||||||
|
private final Lucene95HnswVectorsReader.FieldEntry fieldEntry;
|
||||||
|
|
||||||
|
public SparseOffHeapVectorValues(
|
||||||
|
Lucene95HnswVectorsReader.FieldEntry fieldEntry,
|
||||||
|
IndexInput dataIn,
|
||||||
|
IndexInput slice,
|
||||||
|
int byteSize)
|
||||||
|
throws IOException {
|
||||||
|
|
||||||
|
super(fieldEntry.dimension, fieldEntry.size, slice, byteSize);
|
||||||
|
this.fieldEntry = fieldEntry;
|
||||||
|
final RandomAccessInput addressesData =
|
||||||
|
dataIn.randomAccessSlice(fieldEntry.addressesOffset, fieldEntry.addressesLength);
|
||||||
|
this.dataIn = dataIn;
|
||||||
|
this.ordToDoc = DirectMonotonicReader.getInstance(fieldEntry.meta, addressesData);
|
||||||
|
this.disi =
|
||||||
|
new IndexedDISI(
|
||||||
|
dataIn,
|
||||||
|
fieldEntry.docsWithFieldOffset,
|
||||||
|
fieldEntry.docsWithFieldLength,
|
||||||
|
fieldEntry.jumpTableEntryCount,
|
||||||
|
fieldEntry.denseRankPower,
|
||||||
|
fieldEntry.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
slice.seek((long) (disi.index()) * byteSize);
|
||||||
|
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize, false);
|
||||||
|
return binaryValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return disi.docID();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return disi.nextDoc();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
assert docID() < target;
|
||||||
|
return disi.advance(target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return (int) ordToDoc.get(ord);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
if (acceptDocs == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return new Bits() {
|
||||||
|
@Override
|
||||||
|
public boolean get(int index) {
|
||||||
|
return acceptDocs.get(ordToDoc(index));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int length() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
|
||||||
|
public EmptyOffHeapVectorValues(int dimension) {
|
||||||
|
super(dimension, 0, null, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
private int doc = -1;
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return super.dimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return advance(doc + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
return doc = NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomAccessVectorValues<BytesRef> copy() throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int targetOrd) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ package org.apache.lucene.codecs.lucene95;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
import java.nio.ByteBuffer;
|
||||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.store.RandomAccessInput;
|
import org.apache.lucene.store.RandomAccessInput;
|
||||||
|
@ -29,7 +30,8 @@ import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||||
|
|
||||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
abstract class OffHeapVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
|
|
||||||
protected final int dimension;
|
protected final int dimension;
|
||||||
protected final int size;
|
protected final int size;
|
||||||
|
@ -66,31 +68,17 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
return value;
|
return value;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
readValue(targetOrd);
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private void readValue(int targetOrd) throws IOException {
|
|
||||||
slice.seek((long) targetOrd * byteSize);
|
|
||||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), byteSize);
|
|
||||||
}
|
|
||||||
|
|
||||||
public abstract int ordToDoc(int ord);
|
public abstract int ordToDoc(int ord);
|
||||||
|
|
||||||
static OffHeapVectorValues load(
|
static OffHeapVectorValues load(
|
||||||
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
Lucene95HnswVectorsReader.FieldEntry fieldEntry, IndexInput vectorData) throws IOException {
|
||||||
if (fieldEntry.docsWithFieldOffset == -2) {
|
if (fieldEntry.docsWithFieldOffset == -2
|
||||||
|
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
return new EmptyOffHeapVectorValues(fieldEntry.dimension);
|
||||||
}
|
}
|
||||||
IndexInput bytesSlice =
|
IndexInput bytesSlice =
|
||||||
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
vectorData.slice("vector-data", fieldEntry.vectorDataOffset, fieldEntry.vectorDataLength);
|
||||||
int byteSize =
|
int byteSize = fieldEntry.dimension * Float.BYTES;
|
||||||
switch (fieldEntry.vectorEncoding) {
|
|
||||||
case BYTE -> fieldEntry.dimension;
|
|
||||||
case FLOAT32 -> fieldEntry.dimension * Float.BYTES;
|
|
||||||
};
|
|
||||||
if (fieldEntry.docsWithFieldOffset == -1) {
|
if (fieldEntry.docsWithFieldOffset == -1) {
|
||||||
return new DenseOffHeapVectorValues(
|
return new DenseOffHeapVectorValues(
|
||||||
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
fieldEntry.dimension, fieldEntry.size, bytesSlice, byteSize);
|
||||||
|
@ -143,7 +131,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,7 +207,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -291,7 +279,7 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomAccessVectorValues copy() throws IOException {
|
public RandomAccessVectorValues<float[]> copy() throws IOException {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -300,11 +288,6 @@ abstract class OffHeapVectorValues extends VectorValues implements RandomAccessV
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int ordToDoc(int ord) {
|
public int ordToDoc(int ord) {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
|
|
|
@ -27,6 +27,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentReadState;
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
@ -255,6 +256,16 @@ public abstract class PerFieldKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
KnnVectorsReader knnVectorsReader = fields.get(field);
|
||||||
|
if (knnVectorsReader == null) {
|
||||||
|
return null;
|
||||||
|
} else {
|
||||||
|
return knnVectorsReader.getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
|
@ -0,0 +1,163 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.document;
|
||||||
|
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A field that contains a single byte numeric vector (or none) for each document. Vectors are dense
|
||||||
|
* - that is, every dimension of a vector contains an explicit value, stored packed into an array
|
||||||
|
* (of type byte[]) whose length is the vector dimension. Values can be retrieved using {@link
|
||||||
|
* ByteVectorValues}, which is a forward-only docID-based iterator and also offers random-access by
|
||||||
|
* dense ordinal (not docId). {@link VectorSimilarityFunction} may be used to compare vectors at
|
||||||
|
* query time (for example as part of result ranking). A KnnByteVectorField may be associated with a
|
||||||
|
* search similarity function defining the metric used for nearest-neighbor search among vectors of
|
||||||
|
* that field.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public class KnnByteVectorField extends Field {
|
||||||
|
|
||||||
|
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
||||||
|
if (v == null) {
|
||||||
|
throw new IllegalArgumentException("vector value must not be null");
|
||||||
|
}
|
||||||
|
int dimension = v.length;
|
||||||
|
if (dimension == 0) {
|
||||||
|
throw new IllegalArgumentException("cannot index an empty vector");
|
||||||
|
}
|
||||||
|
if (dimension > ByteVectorValues.MAX_DIMENSIONS) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"cannot index vectors with dimension greater than " + ByteVectorValues.MAX_DIMENSIONS);
|
||||||
|
}
|
||||||
|
if (similarityFunction == null) {
|
||||||
|
throw new IllegalArgumentException("similarity function must not be null");
|
||||||
|
}
|
||||||
|
FieldType type = new FieldType();
|
||||||
|
type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
|
||||||
|
type.freeze();
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create a new vector query for the provided field targeting the byte vector
|
||||||
|
*
|
||||||
|
* @param field The field to query
|
||||||
|
* @param queryVector The byte vector target
|
||||||
|
* @param k The number of nearest neighbors to gather
|
||||||
|
* @return A new vector query
|
||||||
|
*/
|
||||||
|
public static Query newVectorQuery(String field, BytesRef queryVector, int k) {
|
||||||
|
return new KnnByteVectorQuery(field, queryVector, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A convenience method for creating a vector field type.
|
||||||
|
*
|
||||||
|
* @param dimension dimension of vectors
|
||||||
|
* @param similarityFunction a function defining vector proximity.
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
||||||
|
*/
|
||||||
|
public static FieldType createFieldType(
|
||||||
|
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||||
|
FieldType type = new FieldType();
|
||||||
|
type.setVectorAttributes(dimension, VectorEncoding.BYTE, similarityFunction);
|
||||||
|
type.freeze();
|
||||||
|
return type;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||||
|
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||||
|
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
||||||
|
* be constant-length.
|
||||||
|
*
|
||||||
|
* @param name field name
|
||||||
|
* @param vector value
|
||||||
|
* @param similarityFunction a function defining vector proximity.
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||||
|
* dimension > 1024.
|
||||||
|
*/
|
||||||
|
public KnnByteVectorField(
|
||||||
|
String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||||
|
super(name, createType(vector, similarityFunction));
|
||||||
|
fieldsData = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||||
|
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||||
|
* the same dimension and similarity function.
|
||||||
|
*
|
||||||
|
* @param name field name
|
||||||
|
* @param vector value
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||||
|
* dimension > 1024.
|
||||||
|
*/
|
||||||
|
public KnnByteVectorField(String name, BytesRef vector) {
|
||||||
|
this(name, vector, VectorSimilarityFunction.EUCLIDEAN);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||||
|
* no value. Vectors of a single field share the same dimension and similarity function.
|
||||||
|
*
|
||||||
|
* @param name field name
|
||||||
|
* @param vector value
|
||||||
|
* @param fieldType field type
|
||||||
|
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
||||||
|
* dimension > 1024.
|
||||||
|
*/
|
||||||
|
public KnnByteVectorField(String name, BytesRef vector, FieldType fieldType) {
|
||||||
|
super(name, fieldType);
|
||||||
|
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Attempt to create a vector for field "
|
||||||
|
+ name
|
||||||
|
+ " using byte[] but the field encoding is "
|
||||||
|
+ fieldType.vectorEncoding());
|
||||||
|
}
|
||||||
|
fieldsData = vector;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Return the vector value of this field */
|
||||||
|
public BytesRef vectorValue() {
|
||||||
|
return (BytesRef) fieldsData;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set the vector value of this field
|
||||||
|
*
|
||||||
|
* @param value the value to set; must not be null, and length must match the field type
|
||||||
|
*/
|
||||||
|
public void setVectorValue(BytesRef value) {
|
||||||
|
if (value == null) {
|
||||||
|
throw new IllegalArgumentException("value must not be null");
|
||||||
|
}
|
||||||
|
if (value.length != type.vectorDimension()) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"value length " + value.length + " must match field dimension " + type.vectorDimension());
|
||||||
|
}
|
||||||
|
fieldsData = value;
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,7 +20,8 @@ package org.apache.lucene.document;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.search.KnnVectorQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -41,18 +42,7 @@ public class KnnVectorField extends Field {
|
||||||
if (v == null) {
|
if (v == null) {
|
||||||
throw new IllegalArgumentException("vector value must not be null");
|
throw new IllegalArgumentException("vector value must not be null");
|
||||||
}
|
}
|
||||||
return createType(v.length, VectorEncoding.FLOAT32, similarityFunction);
|
int dimension = v.length;
|
||||||
}
|
|
||||||
|
|
||||||
private static FieldType createType(BytesRef v, VectorSimilarityFunction similarityFunction) {
|
|
||||||
if (v == null) {
|
|
||||||
throw new IllegalArgumentException("vector value must not be null");
|
|
||||||
}
|
|
||||||
return createType(v.length, VectorEncoding.BYTE, similarityFunction);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static FieldType createType(
|
|
||||||
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
|
||||||
if (dimension == 0) {
|
if (dimension == 0) {
|
||||||
throw new IllegalArgumentException("cannot index an empty vector");
|
throw new IllegalArgumentException("cannot index an empty vector");
|
||||||
}
|
}
|
||||||
|
@ -64,13 +54,13 @@ public class KnnVectorField extends Field {
|
||||||
throw new IllegalArgumentException("similarity function must not be null");
|
throw new IllegalArgumentException("similarity function must not be null");
|
||||||
}
|
}
|
||||||
FieldType type = new FieldType();
|
FieldType type = new FieldType();
|
||||||
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||||
type.freeze();
|
type.freeze();
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A convenience method for creating a vector field type with the default FLOAT32 encoding.
|
* A convenience method for creating a vector field type.
|
||||||
*
|
*
|
||||||
* @param dimension dimension of vectors
|
* @param dimension dimension of vectors
|
||||||
* @param similarityFunction a function defining vector proximity.
|
* @param similarityFunction a function defining vector proximity.
|
||||||
|
@ -78,23 +68,22 @@ public class KnnVectorField extends Field {
|
||||||
*/
|
*/
|
||||||
public static FieldType createFieldType(
|
public static FieldType createFieldType(
|
||||||
int dimension, VectorSimilarityFunction similarityFunction) {
|
int dimension, VectorSimilarityFunction similarityFunction) {
|
||||||
return createFieldType(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
FieldType type = new FieldType();
|
||||||
|
type.setVectorAttributes(dimension, VectorEncoding.FLOAT32, similarityFunction);
|
||||||
|
type.freeze();
|
||||||
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A convenience method for creating a vector field type.
|
* Create a new vector query for the provided field targeting the float vector
|
||||||
*
|
*
|
||||||
* @param dimension dimension of vectors
|
* @param field The field to query
|
||||||
* @param vectorEncoding the encoding of the scalar values
|
* @param queryVector The float vector target
|
||||||
* @param similarityFunction a function defining vector proximity.
|
* @param k The number of nearest neighbors to gather
|
||||||
* @throws IllegalArgumentException if any parameter is null, or has dimension > 1024.
|
* @return A new vector query
|
||||||
*/
|
*/
|
||||||
public static FieldType createFieldType(
|
public static Query newVectorQuery(String field, float[] queryVector, int k) {
|
||||||
int dimension, VectorEncoding vectorEncoding, VectorSimilarityFunction similarityFunction) {
|
return new KnnVectorQuery(field, queryVector, k);
|
||||||
FieldType type = new FieldType();
|
|
||||||
type.setVectorAttributes(dimension, vectorEncoding, similarityFunction);
|
|
||||||
type.freeze();
|
|
||||||
return type;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -114,23 +103,6 @@ public class KnnVectorField extends Field {
|
||||||
fieldsData = vector;
|
fieldsData = vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
|
||||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
|
||||||
* some vector similarities (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to
|
|
||||||
* be constant-length.
|
|
||||||
*
|
|
||||||
* @param name field name
|
|
||||||
* @param vector value
|
|
||||||
* @param similarityFunction a function defining vector proximity.
|
|
||||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
|
||||||
* dimension > 1024.
|
|
||||||
*/
|
|
||||||
public KnnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
|
||||||
super(name, createType(vector, similarityFunction));
|
|
||||||
fieldsData = vector;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
* Creates a numeric vector field with the default EUCLIDEAN_HNSW (L2) similarity. Fields are
|
||||||
* single-valued: each document has either one value or no value. Vectors of a single field share
|
* single-valued: each document has either one value or no value. Vectors of a single field share
|
||||||
|
@ -167,28 +139,6 @@ public class KnnVectorField extends Field {
|
||||||
fieldsData = vector;
|
fieldsData = vector;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
|
||||||
* no value. Vectors of a single field share the same dimension and similarity function.
|
|
||||||
*
|
|
||||||
* @param name field name
|
|
||||||
* @param vector value
|
|
||||||
* @param fieldType field type
|
|
||||||
* @throws IllegalArgumentException if any parameter is null, or the vector is empty or has
|
|
||||||
* dimension > 1024.
|
|
||||||
*/
|
|
||||||
public KnnVectorField(String name, BytesRef vector, FieldType fieldType) {
|
|
||||||
super(name, fieldType);
|
|
||||||
if (fieldType.vectorEncoding() != VectorEncoding.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Attempt to create a vector for field "
|
|
||||||
+ name
|
|
||||||
+ " using BytesRef but the field encoding is "
|
|
||||||
+ fieldType.vectorEncoding());
|
|
||||||
}
|
|
||||||
fieldsData = vector;
|
|
||||||
}
|
|
||||||
|
|
||||||
/** Return the vector value of this field */
|
/** Return the vector value of this field */
|
||||||
public float[] vectorValue() {
|
public float[] vectorValue() {
|
||||||
return (float[]) fieldsData;
|
return (float[]) fieldsData;
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
package org.apache.lucene.index;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This class provides access to per-document floating point vector values indexed as {@link
|
||||||
|
* KnnByteVectorField}.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract class ByteVectorValues extends DocIdSetIterator {
|
||||||
|
|
||||||
|
/** The maximum length of a vector */
|
||||||
|
public static final int MAX_DIMENSIONS = 1024;
|
||||||
|
|
||||||
|
/** Sole constructor */
|
||||||
|
protected ByteVectorValues() {}
|
||||||
|
|
||||||
|
/** Return the dimension of the vectors */
|
||||||
|
public abstract int dimension();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the number of vectors for this field.
|
||||||
|
*
|
||||||
|
* @return the number of vectors returned by this iterator
|
||||||
|
*/
|
||||||
|
public abstract int size();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final long cost() {
|
||||||
|
return size();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the vector value for the current document ID. It is illegal to call this method when the
|
||||||
|
* iterator is not positioned: before advancing, or after failing to advance. The returned array
|
||||||
|
* may be shared across calls, re-used, and modified as the iterator advances.
|
||||||
|
*
|
||||||
|
* @return the vector value
|
||||||
|
*/
|
||||||
|
public abstract BytesRef vectorValue() throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Return the binary encoded vector value for the current document ID. These are the bytes
|
||||||
|
* corresponding to the float array return by {@link #vectorValue}. It is illegal to call this
|
||||||
|
* method when the iterator is not positioned: before advancing, or after failing to advance. The
|
||||||
|
* returned storage may be shared across calls, re-used and modified as the iterator advances.
|
||||||
|
*
|
||||||
|
* @return the binary value
|
||||||
|
*/
|
||||||
|
public final BytesRef binaryValue() throws IOException {
|
||||||
|
return vectorValue();
|
||||||
|
}
|
||||||
|
}
|
|
@ -34,6 +34,7 @@ import java.util.Iterator;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Locale;
|
import java.util.Locale;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.Objects;
|
||||||
import java.util.concurrent.Callable;
|
import java.util.concurrent.Callable;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.CompletionException;
|
import java.util.concurrent.CompletionException;
|
||||||
|
@ -2588,62 +2589,37 @@ public final class CheckIndex implements Closeable {
|
||||||
+ "\" has vector values but dimension is "
|
+ "\" has vector values but dimension is "
|
||||||
+ dimension);
|
+ dimension);
|
||||||
}
|
}
|
||||||
VectorValues values = reader.getVectorValues(fieldInfo.name);
|
if (reader.getVectorValues(fieldInfo.name) == null
|
||||||
if (values == null) {
|
&& reader.getByteVectorValues(fieldInfo.name) == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
status.totalKnnVectorFields++;
|
status.totalKnnVectorFields++;
|
||||||
|
|
||||||
int docCount = 0;
|
|
||||||
int everyNdoc = Math.max(values.size() / 64, 1);
|
|
||||||
while (values.nextDoc() != NO_MORE_DOCS) {
|
|
||||||
// search the first maxNumSearches vectors to exercise the graph
|
|
||||||
if (values.docID() % everyNdoc == 0) {
|
|
||||||
TopDocs docs =
|
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
case FLOAT32 -> reader
|
case BYTE:
|
||||||
.getVectorReader()
|
checkByteVectorValues(
|
||||||
.search(
|
Objects.requireNonNull(reader.getByteVectorValues(fieldInfo.name)),
|
||||||
fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
fieldInfo,
|
||||||
case BYTE -> reader
|
status,
|
||||||
.getVectorReader()
|
reader);
|
||||||
.search(
|
break;
|
||||||
fieldInfo.name, values.binaryValue(), 10, null, Integer.MAX_VALUE);
|
case FLOAT32:
|
||||||
};
|
checkFloatVectorValues(
|
||||||
if (docs.scoreDocs.length == 0) {
|
Objects.requireNonNull(reader.getVectorValues(fieldInfo.name)),
|
||||||
throw new CheckIndexException(
|
fieldInfo,
|
||||||
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
status,
|
||||||
}
|
reader);
|
||||||
}
|
break;
|
||||||
float[] vectorValue = values.vectorValue();
|
default:
|
||||||
int valueLength = vectorValue.length;
|
|
||||||
if (valueLength != dimension) {
|
|
||||||
throw new CheckIndexException(
|
throw new CheckIndexException(
|
||||||
"Field \""
|
"Field \""
|
||||||
+ fieldInfo.name
|
+ fieldInfo.name
|
||||||
+ "\" has a value whose dimension="
|
+ "\" has unexpected vector encoding: "
|
||||||
+ valueLength
|
+ fieldInfo.getVectorEncoding());
|
||||||
+ " not matching the field's dimension="
|
}
|
||||||
+ dimension);
|
|
||||||
}
|
|
||||||
++docCount;
|
|
||||||
}
|
|
||||||
if (docCount != values.size()) {
|
|
||||||
throw new CheckIndexException(
|
|
||||||
"Field \""
|
|
||||||
+ fieldInfo.name
|
|
||||||
+ "\" has size="
|
|
||||||
+ values.size()
|
|
||||||
+ " but when iterated, returns "
|
|
||||||
+ docCount
|
|
||||||
+ " docs with values");
|
|
||||||
}
|
|
||||||
status.totalVectorValues += docCount;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
msg(
|
msg(
|
||||||
infoStream,
|
infoStream,
|
||||||
String.format(
|
String.format(
|
||||||
|
@ -2667,6 +2643,96 @@ public final class CheckIndex implements Closeable {
|
||||||
return status;
|
return status;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static void checkFloatVectorValues(
|
||||||
|
VectorValues values,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
CheckIndex.Status.VectorValuesStatus status,
|
||||||
|
CodecReader codecReader)
|
||||||
|
throws IOException {
|
||||||
|
int docCount = 0;
|
||||||
|
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||||
|
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
// search a sample of the vectors (doc IDs that are multiples of everyNdoc) to exercise the graph
|
||||||
|
if (values.docID() % everyNdoc == 0) {
|
||||||
|
TopDocs docs =
|
||||||
|
codecReader
|
||||||
|
.getVectorReader()
|
||||||
|
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
||||||
|
if (docs.scoreDocs.length == 0) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int valueLength = values.vectorValue().length;
|
||||||
|
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \""
|
||||||
|
+ fieldInfo.name
|
||||||
|
+ "\" has a value whose dimension="
|
||||||
|
+ valueLength
|
||||||
|
+ " not matching the field's dimension="
|
||||||
|
+ fieldInfo.getVectorDimension());
|
||||||
|
}
|
||||||
|
++docCount;
|
||||||
|
}
|
||||||
|
if (docCount != values.size()) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \""
|
||||||
|
+ fieldInfo.name
|
||||||
|
+ "\" has size="
|
||||||
|
+ values.size()
|
||||||
|
+ " but when iterated, returns "
|
||||||
|
+ docCount
|
||||||
|
+ " docs with values");
|
||||||
|
}
|
||||||
|
status.totalVectorValues += docCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static void checkByteVectorValues(
|
||||||
|
ByteVectorValues values,
|
||||||
|
FieldInfo fieldInfo,
|
||||||
|
CheckIndex.Status.VectorValuesStatus status,
|
||||||
|
CodecReader codecReader)
|
||||||
|
throws IOException {
|
||||||
|
int docCount = 0;
|
||||||
|
int everyNdoc = Math.max(values.size() / 64, 1);
|
||||||
|
while (values.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
// search a sample of the vectors (doc IDs that are multiples of everyNdoc) to exercise the graph
|
||||||
|
if (values.docID() % everyNdoc == 0) {
|
||||||
|
TopDocs docs =
|
||||||
|
codecReader
|
||||||
|
.getVectorReader()
|
||||||
|
.search(fieldInfo.name, values.vectorValue(), 10, null, Integer.MAX_VALUE);
|
||||||
|
if (docs.scoreDocs.length == 0) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
int valueLength = values.vectorValue().length;
|
||||||
|
if (valueLength != fieldInfo.getVectorDimension()) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \""
|
||||||
|
+ fieldInfo.name
|
||||||
|
+ "\" has a value whose dimension="
|
||||||
|
+ valueLength
|
||||||
|
+ " not matching the field's dimension="
|
||||||
|
+ fieldInfo.getVectorDimension());
|
||||||
|
}
|
||||||
|
++docCount;
|
||||||
|
}
|
||||||
|
if (docCount != values.size()) {
|
||||||
|
throw new CheckIndexException(
|
||||||
|
"Field \""
|
||||||
|
+ fieldInfo.name
|
||||||
|
+ "\" has size="
|
||||||
|
+ values.size()
|
||||||
|
+ " but when iterated, returns "
|
||||||
|
+ docCount
|
||||||
|
+ " docs with values");
|
||||||
|
}
|
||||||
|
status.totalVectorValues += docCount;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Walks the entire N-dimensional points space, verifying that all points fall within the last
|
* Walks the entire N-dimensional points space, verifying that all points fall within the last
|
||||||
* cell's boundaries.
|
* cell's boundaries.
|
||||||
|
|
|
@ -218,7 +218,9 @@ public abstract class CodecReader extends LeafReader {
|
||||||
public final VectorValues getVectorValues(String field) throws IOException {
|
public final VectorValues getVectorValues(String field) throws IOException {
|
||||||
ensureOpen();
|
ensureOpen();
|
||||||
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||||
if (fi == null || fi.getVectorDimension() == 0) {
|
if (fi == null
|
||||||
|
|| fi.getVectorDimension() == 0
|
||||||
|
|| fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||||
// Field does not exist or does not index vectors
|
// Field does not exist or does not index vectors
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -226,6 +228,20 @@ public abstract class CodecReader extends LeafReader {
|
||||||
return getVectorReader().getVectorValues(field);
|
return getVectorReader().getVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
ensureOpen();
|
||||||
|
FieldInfo fi = getFieldInfos().fieldInfo(field);
|
||||||
|
if (fi == null
|
||||||
|
|| fi.getVectorDimension() == 0
|
||||||
|
|| fi.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||||
|
// Field does not exist, does not index vectors, or is not BYTE-encoded
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return getVectorReader().getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public final TopDocs searchNearestVectors(
|
public final TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||||
|
|
|
@ -53,6 +53,11 @@ abstract class DocValuesLeafReader extends LeafReader {
|
||||||
throw new UnsupportedOperationException();
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public final ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||||
|
|
|
@ -323,6 +323,15 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
||||||
return new ExitableVectorValues(vectorValues);
|
return new ExitableVectorValues(vectorValues);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
final ByteVectorValues vectorValues = in.getByteVectorValues(field);
|
||||||
|
if (vectorValues == null) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return new ExitableByteVectorValues(vectorValues);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||||
|
@ -387,17 +396,18 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private class ExitableVectorValues extends FilterVectorValues {
|
private class ExitableVectorValues extends VectorValues {
|
||||||
private int docToCheck;
|
private int docToCheck;
|
||||||
|
private final VectorValues vectorValues;
|
||||||
|
|
||||||
public ExitableVectorValues(VectorValues vectorValues) {
|
public ExitableVectorValues(VectorValues vectorValues) {
|
||||||
super(vectorValues);
|
this.vectorValues = vectorValues;
|
||||||
docToCheck = 0;
|
docToCheck = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int advance(int target) throws IOException {
|
public int advance(int target) throws IOException {
|
||||||
final int advance = super.advance(target);
|
final int advance = vectorValues.advance(target);
|
||||||
if (advance >= docToCheck) {
|
if (advance >= docToCheck) {
|
||||||
checkAndThrow();
|
checkAndThrow();
|
||||||
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||||
|
@ -405,9 +415,14 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
||||||
return advance;
|
return advance;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return vectorValues.docID();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int nextDoc() throws IOException {
|
public int nextDoc() throws IOException {
|
||||||
final int nextDoc = super.nextDoc();
|
final int nextDoc = vectorValues.nextDoc();
|
||||||
if (nextDoc >= docToCheck) {
|
if (nextDoc >= docToCheck) {
|
||||||
checkAndThrow();
|
checkAndThrow();
|
||||||
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||||
|
@ -415,14 +430,91 @@ public class ExitableDirectoryReader extends FilterDirectoryReader {
|
||||||
return nextDoc;
|
return nextDoc;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return vectorValues.dimension();
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue() throws IOException {
|
public float[] vectorValue() throws IOException {
|
||||||
return in.vectorValue();
|
return vectorValues.vectorValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return vectorValues.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue() throws IOException {
|
public BytesRef binaryValue() throws IOException {
|
||||||
return in.binaryValue();
|
return vectorValues.binaryValue();
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or
|
||||||
|
* if {@link Thread#interrupted()} returns true.
|
||||||
|
*/
|
||||||
|
private void checkAndThrow() {
|
||||||
|
if (queryTimeout.shouldExit()) {
|
||||||
|
throw new ExitingReaderException(
|
||||||
|
"The request took too long to iterate over vector values. Timeout: "
|
||||||
|
+ queryTimeout.toString()
|
||||||
|
+ ", VectorValues="
|
||||||
|
+ in);
|
||||||
|
} else if (Thread.interrupted()) {
|
||||||
|
throw new ExitingReaderException(
|
||||||
|
"Interrupted while iterating over vector values. VectorValues=" + in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private class ExitableByteVectorValues extends ByteVectorValues {
|
||||||
|
private int docToCheck;
|
||||||
|
private final ByteVectorValues vectorValues;
|
||||||
|
|
||||||
|
public ExitableByteVectorValues(ByteVectorValues vectorValues) {
|
||||||
|
this.vectorValues = vectorValues;
|
||||||
|
docToCheck = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
final int advance = vectorValues.advance(target);
|
||||||
|
if (advance >= docToCheck) {
|
||||||
|
checkAndThrow();
|
||||||
|
docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||||
|
}
|
||||||
|
return advance;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return vectorValues.docID();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
final int nextDoc = vectorValues.nextDoc();
|
||||||
|
if (nextDoc >= docToCheck) {
|
||||||
|
checkAndThrow();
|
||||||
|
docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK;
|
||||||
|
}
|
||||||
|
return nextDoc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return vectorValues.dimension();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return vectorValues.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
return vectorValues.vectorValue();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -351,6 +351,11 @@ public abstract class FilterLeafReader extends LeafReader {
|
||||||
return in.getVectorValues(field);
|
return in.getVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return in.getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||||
|
|
|
@ -38,6 +38,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
||||||
import org.apache.lucene.codecs.PointsFormat;
|
import org.apache.lucene.codecs.PointsFormat;
|
||||||
import org.apache.lucene.codecs.PointsWriter;
|
import org.apache.lucene.codecs.PointsWriter;
|
||||||
import org.apache.lucene.document.FieldType;
|
import org.apache.lucene.document.FieldType;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
|
@ -721,11 +722,7 @@ final class IndexingChain implements Accountable {
|
||||||
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
pf.pointValuesWriter.addPackedValue(docID, field.binaryValue());
|
||||||
}
|
}
|
||||||
if (fieldType.vectorDimension() != 0) {
|
if (fieldType.vectorDimension() != 0) {
|
||||||
switch (fieldType.vectorEncoding()) {
|
indexVectorValue(docID, pf, fieldType.vectorEncoding(), field);
|
||||||
case BYTE -> pf.knnFieldVectorsWriter.addValue(docID, field.binaryValue());
|
|
||||||
case FLOAT32 -> pf.knnFieldVectorsWriter.addValue(
|
|
||||||
docID, ((KnnVectorField) field).vectorValue());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return indexedField;
|
return indexedField;
|
||||||
}
|
}
|
||||||
|
@ -959,6 +956,18 @@ final class IndexingChain implements Accountable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
private void indexVectorValue(
|
||||||
|
int docID, PerField pf, VectorEncoding vectorEncoding, IndexableField field)
|
||||||
|
throws IOException {
|
||||||
|
switch (vectorEncoding) {
|
||||||
|
case BYTE -> ((KnnFieldVectorsWriter<BytesRef>) pf.knnFieldVectorsWriter)
|
||||||
|
.addValue(docID, ((KnnByteVectorField) field).vectorValue());
|
||||||
|
case FLOAT32 -> ((KnnFieldVectorsWriter<float[]>) pf.knnFieldVectorsWriter)
|
||||||
|
.addValue(docID, ((KnnVectorField) field).vectorValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */
|
/** Returns a previously created {@link PerField}, or null if this field name wasn't seen yet. */
|
||||||
private PerField getPerField(String name) {
|
private PerField getPerField(String name) {
|
||||||
final int hashPos = name.hashCode() & hashMask;
|
final int hashPos = name.hashCode() & hashMask;
|
||||||
|
|
|
@ -208,6 +208,14 @@ public abstract class LeafReader extends IndexReader {
|
||||||
*/
|
*/
|
||||||
public abstract VectorValues getVectorValues(String field) throws IOException;
|
public abstract VectorValues getVectorValues(String field) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns {@link ByteVectorValues} for this field, or null if no {@link ByteVectorValues} were
|
||||||
|
* indexed. The returned instance should only be used by a single thread.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
* Return the k nearest neighbor documents as determined by comparison of their vector values for
|
||||||
* this field, to the given vector, by the field's similarity function. The score of each document
|
* this field, to the given vector, by the field's similarity function. The score of each document
|
||||||
|
|
|
@ -408,6 +408,13 @@ public class ParallelLeafReader extends LeafReader {
|
||||||
return reader == null ? null : reader.getVectorValues(fieldName);
|
return reader == null ? null : reader.getVectorValues(fieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
|
||||||
|
ensureOpen();
|
||||||
|
LeafReader reader = fieldToReader.get(fieldName);
|
||||||
|
return reader == null ? null : reader.getByteVectorValues(fieldName);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
String fieldName, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
|
|
|
@ -168,6 +168,11 @@ public final class SlowCodecReaderWrapper {
|
||||||
return reader.getVectorValues(field);
|
return reader.getVectorValues(field);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return reader.getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
|
@ -222,36 +222,23 @@ public final class SortingCodecReader extends FilterCodecReader {
|
||||||
final FixedBitSet docsWithField;
|
final FixedBitSet docsWithField;
|
||||||
final float[][] vectors;
|
final float[][] vectors;
|
||||||
final ByteBuffer vectorAsBytes;
|
final ByteBuffer vectorAsBytes;
|
||||||
final BytesRef[] binaryVectors;
|
|
||||||
|
|
||||||
private int docId = -1;
|
private int docId = -1;
|
||||||
|
|
||||||
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap, VectorEncoding encoding)
|
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||||
throws IOException {
|
|
||||||
this.size = delegate.size();
|
this.size = delegate.size();
|
||||||
this.dimension = delegate.dimension();
|
this.dimension = delegate.dimension();
|
||||||
docsWithField = new FixedBitSet(sortMap.size());
|
docsWithField = new FixedBitSet(sortMap.size());
|
||||||
if (encoding == VectorEncoding.BYTE) {
|
|
||||||
vectors = null;
|
|
||||||
binaryVectors = new BytesRef[sortMap.size()];
|
|
||||||
vectorAsBytes = null;
|
|
||||||
} else {
|
|
||||||
vectors = new float[sortMap.size()][];
|
vectors = new float[sortMap.size()][];
|
||||||
binaryVectors = null;
|
|
||||||
vectorAsBytes =
|
vectorAsBytes =
|
||||||
ByteBuffer.allocate(delegate.dimension() * encoding.byteSize)
|
ByteBuffer.allocate(delegate.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||||
.order(ByteOrder.LITTLE_ENDIAN);
|
.order(ByteOrder.LITTLE_ENDIAN);
|
||||||
}
|
|
||||||
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||||
int newDocID = sortMap.oldToNew(doc);
|
int newDocID = sortMap.oldToNew(doc);
|
||||||
docsWithField.set(newDocID);
|
docsWithField.set(newDocID);
|
||||||
if (encoding == VectorEncoding.BYTE) {
|
|
||||||
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.binaryValue());
|
|
||||||
} else {
|
|
||||||
vectors[newDocID] = delegate.vectorValue().clone();
|
vectors[newDocID] = delegate.vectorValue().clone();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public int docID() {
|
public int docID() {
|
||||||
|
@ -265,13 +252,9 @@ public final class SortingCodecReader extends FilterCodecReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue() throws IOException {
|
public BytesRef binaryValue() throws IOException {
|
||||||
if (binaryVectors != null) {
|
|
||||||
return binaryVectors[docId];
|
|
||||||
} else {
|
|
||||||
vectorAsBytes.asFloatBuffer().put(vectors[docId]);
|
vectorAsBytes.asFloatBuffer().put(vectors[docId]);
|
||||||
return new BytesRef(vectorAsBytes.array());
|
return new BytesRef(vectorAsBytes.array());
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float[] vectorValue() throws IOException {
|
public float[] vectorValue() throws IOException {
|
||||||
|
@ -297,6 +280,60 @@ public final class SortingCodecReader extends FilterCodecReader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static class SortingByteVectorValues extends ByteVectorValues {
|
||||||
|
final int size;
|
||||||
|
final int dimension;
|
||||||
|
final FixedBitSet docsWithField;
|
||||||
|
final BytesRef[] binaryVectors;
|
||||||
|
|
||||||
|
private int docId = -1;
|
||||||
|
|
||||||
|
SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||||
|
this.size = delegate.size();
|
||||||
|
this.dimension = delegate.dimension();
|
||||||
|
docsWithField = new FixedBitSet(sortMap.size());
|
||||||
|
binaryVectors = new BytesRef[sortMap.size()];
|
||||||
|
for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) {
|
||||||
|
int newDocID = sortMap.oldToNew(doc);
|
||||||
|
docsWithField.set(newDocID);
|
||||||
|
binaryVectors[newDocID] = BytesRef.deepCopyOf(delegate.vectorValue());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return docId;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() throws IOException {
|
||||||
|
return advance(docId + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() throws IOException {
|
||||||
|
return binaryVectors[docId];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) throws IOException {
|
||||||
|
if (target >= docsWithField.length()) {
|
||||||
|
return NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
return docId = docsWithField.nextSetBit(target);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
|
* Return a sorted view of <code>reader</code> according to the order defined by <code>sort</code>
|
||||||
* . If the reader is already sorted, this method might return the reader as-is.
|
* . If the reader is already sorted, this method might return the reader as-is.
|
||||||
|
@ -465,9 +502,12 @@ public final class SortingCodecReader extends FilterCodecReader {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
FieldInfo fi = in.getFieldInfos().fieldInfo(field);
|
return new SortingVectorValues(delegate.getVectorValues(field), docMap);
|
||||||
return new SortingVectorValues(
|
}
|
||||||
delegate.getVectorValues(field), docMap, fi.getVectorEncoding());
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return new SortingByteVectorValues(delegate.getByteVectorValues(field), docMap);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -61,8 +61,8 @@ public abstract class VectorValues extends DocIdSetIterator {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the binary encoded vector value for the current document ID. These are the bytes
|
* Return the binary encoded vector value for the current document ID. These are the bytes
|
||||||
* corresponding to the float array return by {@link #vectorValue}. It is illegal to call this
|
* corresponding to the array returned by {@link #vectorValue}. It is illegal to call this method
|
||||||
* method when the iterator is not positioned: before advancing, or after failing to advance. The
|
* when the iterator is not positioned: before advancing, or after failing to advance. The
|
||||||
* returned storage may be shared across calls, re-used and modified as the iterator advances.
|
* returned storage may be shared across calls, re-used and modified as the iterator advances.
|
||||||
*
|
*
|
||||||
* @return the binary value
|
* @return the binary value
|
||||||
|
|
|
@ -31,7 +31,8 @@ import org.apache.lucene.index.Terms;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* A {@link Query} that matches documents that contain either a {@link
|
* A {@link Query} that matches documents that contain either a {@link
|
||||||
* org.apache.lucene.document.KnnVectorField}, or a field that indexes norms or doc values.
|
* org.apache.lucene.document.KnnVectorField}, {@link org.apache.lucene.document.KnnByteVectorField}
|
||||||
|
* or a field that indexes norms or doc values.
|
||||||
*/
|
*/
|
||||||
public class FieldExistsQuery extends Query {
|
public class FieldExistsQuery extends Query {
|
||||||
private String field;
|
private String field;
|
||||||
|
@ -127,7 +128,12 @@ public class FieldExistsQuery extends Query {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
||||||
if (leaf.getVectorValues(field).size() != leaf.maxDoc()) {
|
int numVectors =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case FLOAT32 -> leaf.getVectorValues(field).size();
|
||||||
|
case BYTE -> leaf.getByteVectorValues(field).size();
|
||||||
|
};
|
||||||
|
if (numVectors != leaf.maxDoc()) {
|
||||||
allReadersRewritable = false;
|
allReadersRewritable = false;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -175,7 +181,11 @@ public class FieldExistsQuery extends Query {
|
||||||
if (fieldInfo.hasNorms()) { // the field indexes norms
|
if (fieldInfo.hasNorms()) { // the field indexes norms
|
||||||
iterator = context.reader().getNormValues(field);
|
iterator = context.reader().getNormValues(field);
|
||||||
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
} else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors
|
||||||
iterator = context.reader().getVectorValues(field);
|
iterator =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case FLOAT32 -> context.reader().getVectorValues(field);
|
||||||
|
case BYTE -> context.reader().getByteVectorValues(field);
|
||||||
|
};
|
||||||
} else if (fieldInfo.getDocValuesType()
|
} else if (fieldInfo.getDocValuesType()
|
||||||
!= DocValuesType.NONE) { // the field indexes doc values
|
!= DocValuesType.NONE) { // the field indexes doc values
|
||||||
switch (fieldInfo.getDocValuesType()) {
|
switch (fieldInfo.getDocValuesType()) {
|
||||||
|
|
|
@ -54,7 +54,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
||||||
* @param k the number of documents to find
|
* @param k the number of documents to find
|
||||||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||||
*/
|
*/
|
||||||
public KnnByteVectorQuery(String field, byte[] target, int k) {
|
public KnnByteVectorQuery(String field, BytesRef target, int k) {
|
||||||
this(field, target, k, null);
|
this(field, target, k, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,9 +68,9 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery {
|
||||||
* @param filter a filter applied before the vector search
|
* @param filter a filter applied before the vector search
|
||||||
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
* @throws IllegalArgumentException if <code>k</code> is less than 1
|
||||||
*/
|
*/
|
||||||
public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) {
|
public KnnByteVectorQuery(String field, BytesRef target, int k, Query filter) {
|
||||||
super(field, k, filter);
|
super(field, k, filter);
|
||||||
this.target = new BytesRef(target);
|
this.target = target;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
@ -29,7 +30,6 @@ import org.apache.lucene.util.BytesRef;
|
||||||
* search over the vectors.
|
* search over the vectors.
|
||||||
*/
|
*/
|
||||||
abstract class VectorScorer {
|
abstract class VectorScorer {
|
||||||
protected final VectorValues values;
|
|
||||||
protected final VectorSimilarityFunction similarity;
|
protected final VectorSimilarityFunction similarity;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -48,20 +48,36 @@ abstract class VectorScorer {
|
||||||
|
|
||||||
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
|
static ByteVectorScorer create(LeafReaderContext context, FieldInfo fi, BytesRef query)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
VectorValues values = context.reader().getVectorValues(fi.name);
|
ByteVectorValues values = context.reader().getByteVectorValues(fi.name);
|
||||||
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
VectorSimilarityFunction similarity = fi.getVectorSimilarityFunction();
|
||||||
return new ByteVectorScorer(values, query, similarity);
|
return new ByteVectorScorer(values, query, similarity);
|
||||||
}
|
}
|
||||||
|
|
||||||
VectorScorer(VectorValues values, VectorSimilarityFunction similarity) {
|
VectorScorer(VectorSimilarityFunction similarity) {
|
||||||
this.values = values;
|
|
||||||
this.similarity = similarity;
|
this.similarity = similarity;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Compute the similarity score for the current document. */
|
||||||
|
abstract float score() throws IOException;
|
||||||
|
|
||||||
|
abstract boolean advanceExact(int doc) throws IOException;
|
||||||
|
|
||||||
|
private static class ByteVectorScorer extends VectorScorer {
|
||||||
|
private final BytesRef query;
|
||||||
|
private final ByteVectorValues values;
|
||||||
|
|
||||||
|
protected ByteVectorScorer(
|
||||||
|
ByteVectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
|
||||||
|
super(similarity);
|
||||||
|
this.values = values;
|
||||||
|
this.query = query;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Advance the instance to the given document ID and return true if there is a value for that
|
* Advance the instance to the given document ID and return true if there is a value for that
|
||||||
* document.
|
* document.
|
||||||
*/
|
*/
|
||||||
|
@Override
|
||||||
public boolean advanceExact(int doc) throws IOException {
|
public boolean advanceExact(int doc) throws IOException {
|
||||||
int vectorDoc = values.docID();
|
int vectorDoc = values.docID();
|
||||||
if (vectorDoc < doc) {
|
if (vectorDoc < doc) {
|
||||||
|
@ -70,31 +86,34 @@ abstract class VectorScorer {
|
||||||
return vectorDoc == doc;
|
return vectorDoc == doc;
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Compute the similarity score for the current document. */
|
|
||||||
abstract float score() throws IOException;
|
|
||||||
|
|
||||||
private static class ByteVectorScorer extends VectorScorer {
|
|
||||||
private final BytesRef query;
|
|
||||||
|
|
||||||
protected ByteVectorScorer(
|
|
||||||
VectorValues values, BytesRef query, VectorSimilarityFunction similarity) {
|
|
||||||
super(values, similarity);
|
|
||||||
this.query = query;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public float score() throws IOException {
|
public float score() throws IOException {
|
||||||
return similarity.compare(query, values.binaryValue());
|
return similarity.compare(query, values.vectorValue());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class FloatVectorScorer extends VectorScorer {
|
private static class FloatVectorScorer extends VectorScorer {
|
||||||
private final float[] query;
|
private final float[] query;
|
||||||
|
private final VectorValues values;
|
||||||
|
|
||||||
protected FloatVectorScorer(
|
protected FloatVectorScorer(
|
||||||
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
VectorValues values, float[] query, VectorSimilarityFunction similarity) {
|
||||||
super(values, similarity);
|
super(similarity);
|
||||||
this.query = query;
|
this.query = query;
|
||||||
|
this.values = values;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Advance the instance to the given document ID and return true if there is a value for that
|
||||||
|
* document.
|
||||||
|
*/
|
||||||
|
@Override
|
||||||
|
public boolean advanceExact(int doc) throws IOException {
|
||||||
|
int vectorDoc = values.docID();
|
||||||
|
if (vectorDoc < doc) {
|
||||||
|
vectorDoc = values.advance(doc);
|
||||||
|
}
|
||||||
|
return vectorDoc == doc;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -53,7 +53,7 @@ public final class HnswGraphBuilder<T> {
|
||||||
|
|
||||||
private final VectorSimilarityFunction similarityFunction;
|
private final VectorSimilarityFunction similarityFunction;
|
||||||
private final VectorEncoding vectorEncoding;
|
private final VectorEncoding vectorEncoding;
|
||||||
private final RandomAccessVectorValues vectors;
|
private final RandomAccessVectorValues<T> vectors;
|
||||||
private final SplittableRandom random;
|
private final SplittableRandom random;
|
||||||
private final HnswGraphSearcher<T> graphSearcher;
|
private final HnswGraphSearcher<T> graphSearcher;
|
||||||
|
|
||||||
|
@ -63,10 +63,10 @@ public final class HnswGraphBuilder<T> {
|
||||||
|
|
||||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||||
// colliding
|
// colliding
|
||||||
private final RandomAccessVectorValues vectorsCopy;
|
private final RandomAccessVectorValues<T> vectorsCopy;
|
||||||
|
|
||||||
public static HnswGraphBuilder<?> create(
|
public static <T> HnswGraphBuilder<T> create(
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<T> vectors,
|
||||||
VectorEncoding vectorEncoding,
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
int M,
|
int M,
|
||||||
|
@ -89,7 +89,7 @@ public final class HnswGraphBuilder<T> {
|
||||||
* to ensure repeatable construction.
|
* to ensure repeatable construction.
|
||||||
*/
|
*/
|
||||||
private HnswGraphBuilder(
|
private HnswGraphBuilder(
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<T> vectors,
|
||||||
VectorEncoding vectorEncoding,
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
int M,
|
int M,
|
||||||
|
@ -131,7 +131,7 @@ public final class HnswGraphBuilder<T> {
|
||||||
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
|
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
|
||||||
* independent accessor for the vectors
|
* independent accessor for the vectors
|
||||||
*/
|
*/
|
||||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
public OnHeapHnswGraph build(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
|
||||||
if (vectorsToAdd == this.vectors) {
|
if (vectorsToAdd == this.vectors) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||||
|
@ -143,7 +143,7 @@ public final class HnswGraphBuilder<T> {
|
||||||
return hnsw;
|
return hnsw;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
private void addVectors(RandomAccessVectorValues<T> vectorsToAdd) throws IOException {
|
||||||
long start = System.nanoTime(), t = start;
|
long start = System.nanoTime(), t = start;
|
||||||
// start at node 1! node 0 is added implicitly, in the constructor
|
// start at node 1! node 0 is added implicitly, in the constructor
|
||||||
for (int node = 1; node < vectorsToAdd.size(); node++) {
|
for (int node = 1; node < vectorsToAdd.size(); node++) {
|
||||||
|
@ -189,16 +189,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addGraphNode(int node, RandomAccessVectorValues values) throws IOException {
|
public void addGraphNode(int node, RandomAccessVectorValues<T> values) throws IOException {
|
||||||
addGraphNode(node, getValue(node, values));
|
addGraphNode(node, values.vectorValue(node));
|
||||||
}
|
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
|
||||||
private T getValue(int node, RandomAccessVectorValues values) throws IOException {
|
|
||||||
return switch (vectorEncoding) {
|
|
||||||
case BYTE -> (T) values.binaryValue(node);
|
|
||||||
case FLOAT32 -> (T) values.vectorValue(node);
|
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private long printGraphBuildStatus(int node, long start, long t) {
|
private long printGraphBuildStatus(int node, long start, long t) {
|
||||||
|
@ -281,8 +273,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
|
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
return switch (vectorEncoding) {
|
return switch (vectorEncoding) {
|
||||||
case BYTE -> isDiverse(vectors.binaryValue(candidate), neighbors, score);
|
case BYTE -> isDiverse((BytesRef) vectors.vectorValue(candidate), neighbors, score);
|
||||||
case FLOAT32 -> isDiverse(vectors.vectorValue(candidate), neighbors, score);
|
case FLOAT32 -> isDiverse((float[]) vectors.vectorValue(candidate), neighbors, score);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -290,7 +282,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
throws IOException {
|
throws IOException {
|
||||||
for (int i = 0; i < neighbors.size(); i++) {
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
float neighborSimilarity =
|
float neighborSimilarity =
|
||||||
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
|
similarityFunction.compare(
|
||||||
|
candidate, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||||
if (neighborSimilarity >= score) {
|
if (neighborSimilarity >= score) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -302,7 +295,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
throws IOException {
|
throws IOException {
|
||||||
for (int i = 0; i < neighbors.size(); i++) {
|
for (int i = 0; i < neighbors.size(); i++) {
|
||||||
float neighborSimilarity =
|
float neighborSimilarity =
|
||||||
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
|
similarityFunction.compare(
|
||||||
|
candidate, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||||
if (neighborSimilarity >= score) {
|
if (neighborSimilarity >= score) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
@ -327,9 +321,10 @@ public final class HnswGraphBuilder<T> {
|
||||||
throws IOException {
|
throws IOException {
|
||||||
int candidateNode = neighbors.node[candidateIndex];
|
int candidateNode = neighbors.node[candidateIndex];
|
||||||
return switch (vectorEncoding) {
|
return switch (vectorEncoding) {
|
||||||
case BYTE -> isWorstNonDiverse(candidateIndex, vectors.binaryValue(candidateNode), neighbors);
|
case BYTE -> isWorstNonDiverse(
|
||||||
|
candidateIndex, (BytesRef) vectors.vectorValue(candidateNode), neighbors);
|
||||||
case FLOAT32 -> isWorstNonDiverse(
|
case FLOAT32 -> isWorstNonDiverse(
|
||||||
candidateIndex, vectors.vectorValue(candidateNode), neighbors);
|
candidateIndex, (float[]) vectors.vectorValue(candidateNode), neighbors);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -338,7 +333,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||||
float neighborSimilarity =
|
float neighborSimilarity =
|
||||||
similarityFunction.compare(candidateVector, vectorsCopy.vectorValue(neighbors.node[i]));
|
similarityFunction.compare(
|
||||||
|
candidateVector, (float[]) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||||
// candidate node is too similar to node i given its score relative to the base node
|
// candidate node is too similar to node i given its score relative to the base node
|
||||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -352,7 +348,8 @@ public final class HnswGraphBuilder<T> {
|
||||||
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
float minAcceptedSimilarity = neighbors.score[candidateIndex];
|
||||||
for (int i = candidateIndex - 1; i >= 0; i--) {
|
for (int i = candidateIndex - 1; i >= 0; i--) {
|
||||||
float neighborSimilarity =
|
float neighborSimilarity =
|
||||||
similarityFunction.compare(candidateVector, vectorsCopy.binaryValue(neighbors.node[i]));
|
similarityFunction.compare(
|
||||||
|
candidateVector, (BytesRef) vectorsCopy.vectorValue(neighbors.node[i]));
|
||||||
// candidate node is too similar to node i given its score relative to the base node
|
// candidate node is too similar to node i given its score relative to the base node
|
||||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -81,7 +81,7 @@ public class HnswGraphSearcher<T> {
|
||||||
public static NeighborQueue search(
|
public static NeighborQueue search(
|
||||||
float[] query,
|
float[] query,
|
||||||
int topK,
|
int topK,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<float[]> vectors,
|
||||||
VectorEncoding vectorEncoding,
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
HnswGraph graph,
|
HnswGraph graph,
|
||||||
|
@ -137,7 +137,7 @@ public class HnswGraphSearcher<T> {
|
||||||
public static NeighborQueue search(
|
public static NeighborQueue search(
|
||||||
BytesRef query,
|
BytesRef query,
|
||||||
int topK,
|
int topK,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<BytesRef> vectors,
|
||||||
VectorEncoding vectorEncoding,
|
VectorEncoding vectorEncoding,
|
||||||
VectorSimilarityFunction similarityFunction,
|
VectorSimilarityFunction similarityFunction,
|
||||||
HnswGraph graph,
|
HnswGraph graph,
|
||||||
|
@ -198,7 +198,7 @@ public class HnswGraphSearcher<T> {
|
||||||
int topK,
|
int topK,
|
||||||
int level,
|
int level,
|
||||||
final int[] eps,
|
final int[] eps,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<T> vectors,
|
||||||
HnswGraph graph)
|
HnswGraph graph)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
|
return searchLevel(query, topK, level, eps, vectors, graph, null, Integer.MAX_VALUE);
|
||||||
|
@ -209,7 +209,7 @@ public class HnswGraphSearcher<T> {
|
||||||
int topK,
|
int topK,
|
||||||
int level,
|
int level,
|
||||||
final int[] eps,
|
final int[] eps,
|
||||||
RandomAccessVectorValues vectors,
|
RandomAccessVectorValues<T> vectors,
|
||||||
HnswGraph graph,
|
HnswGraph graph,
|
||||||
Bits acceptOrds,
|
Bits acceptOrds,
|
||||||
int visitedLimit)
|
int visitedLimit)
|
||||||
|
@ -279,11 +279,11 @@ public class HnswGraphSearcher<T> {
|
||||||
return results;
|
return results;
|
||||||
}
|
}
|
||||||
|
|
||||||
private float compare(T query, RandomAccessVectorValues vectors, int ord) throws IOException {
|
private float compare(T query, RandomAccessVectorValues<T> vectors, int ord) throws IOException {
|
||||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
if (vectorEncoding == VectorEncoding.BYTE) {
|
||||||
return similarityFunction.compare((BytesRef) query, vectors.binaryValue(ord));
|
return similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(ord));
|
||||||
} else {
|
} else {
|
||||||
return similarityFunction.compare((float[]) query, vectors.vectorValue(ord));
|
return similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(ord));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,6 @@
|
||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.util.BytesRef;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
||||||
|
@ -26,7 +25,7 @@ import org.apache.lucene.util.BytesRef;
|
||||||
*
|
*
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public interface RandomAccessVectorValues {
|
public interface RandomAccessVectorValues<T> {
|
||||||
|
|
||||||
/** Return the number of vector values */
|
/** Return the number of vector values */
|
||||||
int size();
|
int size();
|
||||||
|
@ -35,26 +34,16 @@ public interface RandomAccessVectorValues {
|
||||||
int dimension();
|
int dimension();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the vector value indexed at the given ordinal. The provided floating point array may be
|
* Return the vector value indexed at the given ordinal.
|
||||||
* shared and overwritten by subsequent calls to this method and {@link #binaryValue(int)}.
|
|
||||||
*
|
*
|
||||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||||
*/
|
*/
|
||||||
float[] vectorValue(int targetOrd) throws IOException;
|
T vectorValue(int targetOrd) throws IOException;
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the vector indexed at the given ordinal value as an array of bytes in a BytesRef; these
|
|
||||||
* are the bytes corresponding to the float array. The provided bytes may be shared and
|
|
||||||
* overwritten by subsequent calls to this method and {@link #vectorValue(int)}.
|
|
||||||
*
|
|
||||||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
|
||||||
*/
|
|
||||||
BytesRef binaryValue(int targetOrd) throws IOException;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
|
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
|
||||||
* access different values at once, to avoid overwriting the underlying float vector returned by
|
* access different values at once, to avoid overwriting the underlying float vector returned by
|
||||||
* {@link RandomAccessVectorValues#vectorValue}.
|
* {@link RandomAccessVectorValues#vectorValue}.
|
||||||
*/
|
*/
|
||||||
RandomAccessVectorValues copy() throws IOException;
|
RandomAccessVectorValues<T> copy() throws IOException;
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
import java.io.StringReader;
|
import java.io.StringReader;
|
||||||
import java.nio.charset.StandardCharsets;
|
import java.nio.charset.StandardCharsets;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.IndexWriter;
|
import org.apache.lucene.index.IndexWriter;
|
||||||
|
@ -611,25 +612,22 @@ public class TestField extends LuceneTestCase {
|
||||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
BytesRef br = newBytesRef(new byte[5]);
|
BytesRef br = newBytesRef(new byte[5]);
|
||||||
Field field = new KnnVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
|
Field field = new KnnByteVectorField("binary", br, VectorSimilarityFunction.EUCLIDEAN);
|
||||||
expectThrows(
|
expectThrows(
|
||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
|
() -> new KnnVectorField("bogus", new float[] {1}, (FieldType) field.fieldType()));
|
||||||
float[] vector = new float[] {1, 2};
|
float[] vector = new float[] {1, 2};
|
||||||
Field field2 = new KnnVectorField("float", vector);
|
Field field2 = new KnnVectorField("float", vector);
|
||||||
expectThrows(
|
|
||||||
IllegalArgumentException.class,
|
|
||||||
() -> new KnnVectorField("bogus", br, (FieldType) field2.fieldType()));
|
|
||||||
assertEquals(br, field.binaryValue());
|
assertEquals(br, field.binaryValue());
|
||||||
doc.add(field);
|
doc.add(field);
|
||||||
doc.add(field2);
|
doc.add(field2);
|
||||||
w.addDocument(doc);
|
w.addDocument(doc);
|
||||||
try (IndexReader r = DirectoryReader.open(w)) {
|
try (IndexReader r = DirectoryReader.open(w)) {
|
||||||
VectorValues binary = r.leaves().get(0).reader().getVectorValues("binary");
|
ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary");
|
||||||
assertEquals(1, binary.size());
|
assertEquals(1, binary.size());
|
||||||
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
|
assertNotEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||||
assertEquals(br, binary.binaryValue());
|
|
||||||
assertNotNull(binary.vectorValue());
|
assertNotNull(binary.vectorValue());
|
||||||
|
assertEquals(br, binary.vectorValue());
|
||||||
assertEquals(NO_MORE_DOCS, binary.nextDoc());
|
assertEquals(NO_MORE_DOCS, binary.nextDoc());
|
||||||
|
|
||||||
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
|
VectorValues floatValues = r.leaves().get(0).reader().getVectorValues("float");
|
||||||
|
|
|
@ -112,6 +112,11 @@ public class TestSegmentToThreadMapping extends LuceneTestCase {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
package org.apache.lucene.search;
|
package org.apache.lucene.search;
|
||||||
|
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
@ -27,12 +27,12 @@ import org.apache.lucene.util.TestVectorUtil;
|
||||||
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
||||||
@Override
|
@Override
|
||||||
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
|
AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) {
|
||||||
return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter);
|
return new KnnByteVectorQuery(field, new BytesRef(floatToBytes(query)), k, queryFilter);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
|
AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) {
|
||||||
return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query);
|
return new ThrowingKnnVectorQuery(field, new BytesRef(floatToBytes(vec)), k, query);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -49,12 +49,12 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
||||||
@Override
|
@Override
|
||||||
Field getKnnVectorField(
|
Field getKnnVectorField(
|
||||||
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||||
return new KnnVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
|
return new KnnByteVectorField(name, new BytesRef(floatToBytes(vector)), similarityFunction);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Field getKnnVectorField(String name, float[] vector) {
|
Field getKnnVectorField(String name, float[] vector) {
|
||||||
return new KnnVectorField(
|
return new KnnByteVectorField(
|
||||||
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
|
name, new BytesRef(floatToBytes(vector)), VectorSimilarityFunction.EUCLIDEAN);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
|
||||||
|
|
||||||
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
|
private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery {
|
||||||
|
|
||||||
public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) {
|
public ThrowingKnnVectorQuery(String field, BytesRef target, int k, Query filter) {
|
||||||
super(field, target, k, filter);
|
super(field, target, k, filter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ import com.carrotsearch.randomizedtesting.generators.RandomPicks;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.document.StringField;
|
import org.apache.lucene.document.StringField;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
|
@ -79,7 +80,7 @@ public class TestVectorScorer extends LuceneTestCase {
|
||||||
for (int j = 0; j < v.length; j++) {
|
for (int j = 0; j < v.length; j++) {
|
||||||
v.bytes[j] = (byte) contents[i][j];
|
v.bytes[j] = (byte) contents[i][j];
|
||||||
}
|
}
|
||||||
doc.add(new KnnVectorField(field, v, EUCLIDEAN));
|
doc.add(new KnnByteVectorField(field, v, EUCLIDEAN));
|
||||||
} else {
|
} else {
|
||||||
doc.add(new KnnVectorField(field, contents[i]));
|
doc.add(new KnnVectorField(field, contents[i]));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
abstract class AbstractMockVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||||
|
|
||||||
|
protected final int dimension;
|
||||||
|
protected final T[] denseValues;
|
||||||
|
protected final T[] values;
|
||||||
|
protected final int numVectors;
|
||||||
|
protected final BytesRef binaryValue;
|
||||||
|
|
||||||
|
protected int pos = -1;
|
||||||
|
|
||||||
|
AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) {
|
||||||
|
this.dimension = dimension;
|
||||||
|
this.values = values;
|
||||||
|
this.denseValues = denseValues;
|
||||||
|
// used by tests that build a graph from bytes rather than floats
|
||||||
|
binaryValue = new BytesRef(dimension);
|
||||||
|
binaryValue.length = dimension;
|
||||||
|
this.numVectors = numVectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return numVectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return dimension;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public T vectorValue(int targetOrd) {
|
||||||
|
return denseValues[targetOrd];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public abstract AbstractMockVectorValues<T> copy();
|
||||||
|
|
||||||
|
public abstract T vectorValue() throws IOException;
|
||||||
|
|
||||||
|
private boolean seek(int target) {
|
||||||
|
if (target >= 0 && target < values.length && values[target] != null) {
|
||||||
|
pos = target;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int docID() {
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int nextDoc() {
|
||||||
|
return advance(pos + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
public int advance(int target) {
|
||||||
|
while (++pos < values.length) {
|
||||||
|
if (seek(pos)) {
|
||||||
|
return pos;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,7 +20,6 @@ package org.apache.lucene.util.hnsw;
|
||||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
|
import static org.apache.lucene.tests.util.RamUsageTester.ramUsed;
|
||||||
import static org.apache.lucene.util.VectorUtil.toBytesRef;
|
|
||||||
|
|
||||||
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
@ -36,21 +35,23 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
||||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.NumericDocValuesField;
|
import org.apache.lucene.document.NumericDocValuesField;
|
||||||
import org.apache.lucene.document.StoredField;
|
import org.apache.lucene.document.StoredField;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CodecReader;
|
import org.apache.lucene.index.CodecReader;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.IndexWriter;
|
import org.apache.lucene.index.IndexWriter;
|
||||||
import org.apache.lucene.index.IndexWriterConfig;
|
import org.apache.lucene.index.IndexWriterConfig;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
import org.apache.lucene.index.LeafReaderContext;
|
import org.apache.lucene.index.LeafReaderContext;
|
||||||
import org.apache.lucene.index.StoredFields;
|
import org.apache.lucene.index.StoredFields;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.index.VectorValues;
|
import org.apache.lucene.index.VectorValues;
|
||||||
import org.apache.lucene.search.IndexSearcher;
|
import org.apache.lucene.search.IndexSearcher;
|
||||||
import org.apache.lucene.search.KnnVectorQuery;
|
import org.apache.lucene.search.Query;
|
||||||
import org.apache.lucene.search.ScoreDoc;
|
import org.apache.lucene.search.ScoreDoc;
|
||||||
import org.apache.lucene.search.Sort;
|
import org.apache.lucene.search.Sort;
|
||||||
import org.apache.lucene.search.SortField;
|
import org.apache.lucene.search.SortField;
|
||||||
|
@ -65,19 +66,30 @@ import org.apache.lucene.util.FixedBitSet;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
import org.apache.lucene.util.VectorUtil;
|
import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||||
import org.junit.Before;
|
|
||||||
|
|
||||||
/** Tests HNSW KNN graphs */
|
/** Tests HNSW KNN graphs */
|
||||||
public class TestHnswGraph extends LuceneTestCase {
|
abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
||||||
|
|
||||||
VectorSimilarityFunction similarityFunction;
|
VectorSimilarityFunction similarityFunction;
|
||||||
VectorEncoding vectorEncoding;
|
|
||||||
|
|
||||||
@Before
|
abstract VectorEncoding getVectorEncoding();
|
||||||
public void setup() {
|
|
||||||
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
abstract Query knnQuery(String field, T vector, int k);
|
||||||
vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
|
||||||
}
|
abstract T randomVector(int dim);
|
||||||
|
|
||||||
|
abstract AbstractMockVectorValues<T> vectorValues(int size, int dimension);
|
||||||
|
|
||||||
|
abstract AbstractMockVectorValues<T> vectorValues(float[][] values);
|
||||||
|
|
||||||
|
abstract AbstractMockVectorValues<T> vectorValues(LeafReader reader, String fieldName)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction);
|
||||||
|
|
||||||
|
abstract RandomAccessVectorValues<T> circularVectorValues(int nDoc);
|
||||||
|
|
||||||
|
abstract T getTargetVector();
|
||||||
|
|
||||||
// test writing out and reading in a graph gives the expected graph
|
// test writing out and reading in a graph gives the expected graph
|
||||||
public void testReadWrite() throws IOException {
|
public void testReadWrite() throws IOException {
|
||||||
|
@ -86,10 +98,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
int M = random().nextInt(4) + 2;
|
int M = random().nextInt(4) + 2;
|
||||||
int beamWidth = random().nextInt(10) + 5;
|
int beamWidth = random().nextInt(10) + 5;
|
||||||
long seed = random().nextLong();
|
long seed = random().nextLong();
|
||||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, vectorEncoding, random());
|
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||||
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
AbstractMockVectorValues<T> v2 = vectors.copy(), v3 = vectors.copy();
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
|
HnswGraphBuilder.create(
|
||||||
|
vectors, getVectorEncoding(), similarityFunction, M, beamWidth, seed);
|
||||||
HnswGraph hnsw = builder.build(vectors.copy());
|
HnswGraph hnsw = builder.build(vectors.copy());
|
||||||
|
|
||||||
// Recreate the graph while indexing with the same random seed and write it out
|
// Recreate the graph while indexing with the same random seed and write it out
|
||||||
|
@ -115,7 +128,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
indexedDoc++;
|
indexedDoc++;
|
||||||
}
|
}
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(new KnnVectorField("field", v2.vectorValue(), similarityFunction));
|
doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction));
|
||||||
doc.add(new StoredField("id", v2.docID()));
|
doc.add(new StoredField("id", v2.docID()));
|
||||||
iw.addDocument(doc);
|
iw.addDocument(doc);
|
||||||
nVec++;
|
nVec++;
|
||||||
|
@ -124,7 +137,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
try (IndexReader reader = DirectoryReader.open(dir)) {
|
try (IndexReader reader = DirectoryReader.open(dir)) {
|
||||||
for (LeafReaderContext ctx : reader.leaves()) {
|
for (LeafReaderContext ctx : reader.leaves()) {
|
||||||
VectorValues values = ctx.reader().getVectorValues("field");
|
AbstractMockVectorValues<T> values = vectorValues(ctx.reader(), "field");
|
||||||
assertEquals(dim, values.dimension());
|
assertEquals(dim, values.dimension());
|
||||||
assertEquals(nVec, values.size());
|
assertEquals(nVec, values.size());
|
||||||
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
assertEquals(indexedDoc, ctx.reader().maxDoc());
|
||||||
|
@ -142,15 +155,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private VectorEncoding randomVectorEncoding() {
|
|
||||||
return VectorEncoding.values()[random().nextInt(VectorEncoding.values().length)];
|
|
||||||
}
|
|
||||||
|
|
||||||
// test that sorted index returns the same search results as unsorted
|
// test that sorted index returns the same search results as unsorted
|
||||||
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException {
|
||||||
int dim = random().nextInt(10) + 3;
|
int dim = random().nextInt(10) + 3;
|
||||||
int nDoc = random().nextInt(200) + 100;
|
int nDoc = random().nextInt(200) + 100;
|
||||||
RandomVectorValues vectors = new RandomVectorValues(nDoc, dim, random());
|
AbstractMockVectorValues<T> vectors = vectorValues(nDoc, dim);
|
||||||
|
|
||||||
int M = random().nextInt(10) + 5;
|
int M = random().nextInt(10) + 5;
|
||||||
int beamWidth = random().nextInt(10) + 5;
|
int beamWidth = random().nextInt(10) + 5;
|
||||||
|
@ -190,7 +199,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
indexedDoc++;
|
indexedDoc++;
|
||||||
}
|
}
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(new KnnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction));
|
||||||
doc.add(new StoredField("id", vectors.docID()));
|
doc.add(new StoredField("id", vectors.docID()));
|
||||||
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
|
doc.add(new NumericDocValuesField("sortkey", random().nextLong()));
|
||||||
iw.addDocument(doc);
|
iw.addDocument(doc);
|
||||||
|
@ -206,7 +215,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
for (int i = 0; i < 10; i++) {
|
for (int i = 0; i < 10; i++) {
|
||||||
// ask to explore a lot of candidates to ensure the same returned hits,
|
// ask to explore a lot of candidates to ensure the same returned hits,
|
||||||
// as graphs of 2 indices are organized differently
|
// as graphs of 2 indices are organized differently
|
||||||
KnnVectorQuery query = new KnnVectorQuery("vector", randomVector(random(), dim), 50);
|
Query query = knnQuery("vector", randomVector(dim), 50);
|
||||||
List<String> ids1 = new ArrayList<>();
|
List<String> ids1 = new ArrayList<>();
|
||||||
List<Integer> docs1 = new ArrayList<>();
|
List<Integer> docs1 = new ArrayList<>();
|
||||||
List<String> ids2 = new ArrayList<>();
|
List<String> ids2 = new ArrayList<>();
|
||||||
|
@ -241,7 +250,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
void assertGraphEqual(HnswGraph g, HnswGraph h) throws IOException {
|
||||||
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
assertEquals("the number of levels in the graphs are different!", g.numLevels(), h.numLevels());
|
||||||
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
assertEquals("the number of nodes in the graphs are different!", g.size(), h.size());
|
||||||
|
|
||||||
|
@ -271,32 +280,32 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
// Make sure we actually approximately find the closest k elements. Mostly this is about
|
// Make sure we actually approximately find the closest k elements. Mostly this is about
|
||||||
// ensuring that we have all the distance functions, comparators, priority queues and so on
|
// ensuring that we have all the distance functions, comparators, priority queues and so on
|
||||||
// oriented in the right directions
|
// oriented in the right directions
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testAknnDiverse() throws IOException {
|
public void testAknnDiverse() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 10, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
// run some searches
|
// run some searches
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
switch (vectorEncoding) {
|
switch (getVectorEncoding()) {
|
||||||
case BYTE -> HnswGraphSearcher.search(
|
case BYTE -> HnswGraphSearcher.search(
|
||||||
getTargetByteVector(),
|
(BytesRef) getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
null,
|
null,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
case FLOAT32 -> HnswGraphSearcher.search(
|
case FLOAT32 -> HnswGraphSearcher.search(
|
||||||
getTargetVector(),
|
(float[]) getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
null,
|
null,
|
||||||
|
@ -323,33 +332,33 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testSearchWithAcceptOrds() throws IOException {
|
public void testSearchWithAcceptOrds() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
vectorEncoding = randomVectorEncoding();
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder<?> builder =
|
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
// the first 10 docs must not be deleted to ensure the expected recall
|
// the first 10 docs must not be deleted to ensure the expected recall
|
||||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
Bits acceptOrds = createRandomAcceptOrds(10, nDoc);
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
switch (vectorEncoding) {
|
switch (getVectorEncoding()) {
|
||||||
case BYTE -> HnswGraphSearcher.search(
|
case BYTE -> HnswGraphSearcher.search(
|
||||||
getTargetByteVector(),
|
(BytesRef) getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
case FLOAT32 -> HnswGraphSearcher.search(
|
case FLOAT32 -> HnswGraphSearcher.search(
|
||||||
getTargetVector(),
|
(float[]) getTargetVector(),
|
||||||
10,
|
10,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
|
@ -367,39 +376,39 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
assertTrue("sum(result docs)=" + sum, sum < 75);
|
assertTrue("sum(result docs)=" + sum, sum < 75);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
public void testSearchWithSelectiveAcceptOrds() throws IOException {
|
||||||
int nDoc = 100;
|
int nDoc = 100;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
// Only mark a few vectors as accepted
|
// Only mark a few vectors as accepted
|
||||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
BitSet acceptOrds = new FixedBitSet(nDoc);
|
||||||
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
|
for (int i = 0; i < nDoc; i += random().nextInt(15, 20)) {
|
||||||
acceptOrds.set(i);
|
acceptOrds.set(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check the search finds all accepted vectors
|
// Check the search finds all accepted vectors
|
||||||
int numAccepted = acceptOrds.cardinality();
|
int numAccepted = acceptOrds.cardinality();
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
switch (vectorEncoding) {
|
switch (getVectorEncoding()) {
|
||||||
case FLOAT32 -> HnswGraphSearcher.search(
|
case FLOAT32 -> HnswGraphSearcher.search(
|
||||||
getTargetVector(),
|
(float[]) getTargetVector(),
|
||||||
numAccepted,
|
numAccepted,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
case BYTE -> HnswGraphSearcher.search(
|
case BYTE -> HnswGraphSearcher.search(
|
||||||
getTargetByteVector(),
|
(BytesRef) getTargetVector(),
|
||||||
numAccepted,
|
numAccepted,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
|
@ -413,81 +422,37 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private float[] getTargetVector() {
|
@SuppressWarnings("unchecked")
|
||||||
return new float[] {1, 0};
|
|
||||||
}
|
|
||||||
|
|
||||||
private BytesRef getTargetByteVector() {
|
|
||||||
return new BytesRef(new byte[] {1, 0});
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
|
||||||
int nDoc = 1000;
|
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
|
||||||
HnswGraphBuilder<?> builder =
|
|
||||||
HnswGraphBuilder.create(
|
|
||||||
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
|
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
|
||||||
|
|
||||||
// Skip over half of the documents that are closest to the query vector
|
|
||||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
|
||||||
for (int i = 500; i < nDoc; i++) {
|
|
||||||
acceptOrds.set(i);
|
|
||||||
}
|
|
||||||
NeighborQueue nn =
|
|
||||||
HnswGraphSearcher.search(
|
|
||||||
getTargetVector(),
|
|
||||||
10,
|
|
||||||
vectors.copy(),
|
|
||||||
VectorEncoding.FLOAT32,
|
|
||||||
similarityFunction,
|
|
||||||
hnsw,
|
|
||||||
acceptOrds,
|
|
||||||
Integer.MAX_VALUE);
|
|
||||||
int[] nodes = nn.nodes();
|
|
||||||
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
|
||||||
int sum = 0;
|
|
||||||
for (int node : nodes) {
|
|
||||||
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
|
||||||
sum += node;
|
|
||||||
}
|
|
||||||
// We still expect to get reasonable recall. The lowest non-skipped docIds
|
|
||||||
// are closest to the query vector: sum(500,509) = 5045
|
|
||||||
assertTrue("sum(result docs)=" + sum, sum < 5100);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testVisitedLimit() throws IOException {
|
public void testVisitedLimit() throws IOException {
|
||||||
int nDoc = 500;
|
int nDoc = 500;
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
CircularVectorValues vectors = new CircularVectorValues(nDoc);
|
RandomAccessVectorValues<T> vectors = circularVectorValues(nDoc);
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
|
|
||||||
int topK = 50;
|
int topK = 50;
|
||||||
int visitedLimit = topK + random().nextInt(5);
|
int visitedLimit = topK + random().nextInt(5);
|
||||||
NeighborQueue nn =
|
NeighborQueue nn =
|
||||||
switch (vectorEncoding) {
|
switch (getVectorEncoding()) {
|
||||||
case FLOAT32 -> HnswGraphSearcher.search(
|
case FLOAT32 -> HnswGraphSearcher.search(
|
||||||
getTargetVector(),
|
(float[]) getTargetVector(),
|
||||||
topK,
|
topK,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<float[]>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
createRandomAcceptOrds(0, vectors.size),
|
createRandomAcceptOrds(0, nDoc),
|
||||||
visitedLimit);
|
visitedLimit);
|
||||||
case BYTE -> HnswGraphSearcher.search(
|
case BYTE -> HnswGraphSearcher.search(
|
||||||
getTargetByteVector(),
|
(BytesRef) getTargetVector(),
|
||||||
topK,
|
topK,
|
||||||
vectors.copy(),
|
(RandomAccessVectorValues<BytesRef>) vectors.copy(),
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
createRandomAcceptOrds(0, vectors.size),
|
createRandomAcceptOrds(0, nDoc),
|
||||||
visitedLimit);
|
visitedLimit);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -504,8 +469,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() ->
|
() ->
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
new RandomVectorValues(1, 1, random()),
|
vectorValues(1, 1),
|
||||||
VectorEncoding.FLOAT32,
|
getVectorEncoding(),
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
0,
|
0,
|
||||||
10,
|
10,
|
||||||
|
@ -515,8 +480,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
IllegalArgumentException.class,
|
IllegalArgumentException.class,
|
||||||
() ->
|
() ->
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
new RandomVectorValues(1, 1, random()),
|
vectorValues(1, 1),
|
||||||
VectorEncoding.FLOAT32,
|
getVectorEncoding(),
|
||||||
VectorSimilarityFunction.EUCLIDEAN,
|
VectorSimilarityFunction.EUCLIDEAN,
|
||||||
10,
|
10,
|
||||||
0,
|
0,
|
||||||
|
@ -530,13 +495,11 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
|
||||||
VectorSimilarityFunction similarityFunction =
|
VectorSimilarityFunction similarityFunction =
|
||||||
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||||
VectorEncoding vectorEncoding = RandomizedTest.randomFrom(VectorEncoding.values());
|
RandomAccessVectorValues<T> vectors = vectorValues(size, dim);
|
||||||
TestHnswGraph.RandomVectorValues vectors =
|
|
||||||
new TestHnswGraph.RandomVectorValues(size, dim, vectorEncoding, random());
|
|
||||||
|
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, M, M * 2, random().nextLong());
|
vectors, getVectorEncoding(), similarityFunction, M, M * 2, random().nextLong());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
|
long estimated = RamUsageEstimator.sizeOfObject(hnsw);
|
||||||
long actual = ramUsed(hnsw);
|
long actual = ramUsed(hnsw);
|
||||||
|
@ -546,7 +509,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
public void testDiversity() throws IOException {
|
public void testDiversity() throws IOException {
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||||
// Some carefully checked test cases with simple 2d vectors on the unit circle:
|
// Some carefully checked test cases with simple 2d vectors on the unit circle:
|
||||||
float[][] values = {
|
float[][] values = {
|
||||||
|
@ -558,21 +520,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
unitVector2d(0.77),
|
unitVector2d(0.77),
|
||||||
unitVector2d(0.6)
|
unitVector2d(0.6)
|
||||||
};
|
};
|
||||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||||
for (float[] v : values) {
|
|
||||||
for (int i = 0; i < v.length; i++) {
|
|
||||||
v[i] *= 127;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MockVectorValues vectors = new MockVectorValues(values);
|
|
||||||
// First add nodes until everybody gets a full neighbor list
|
// First add nodes until everybody gets a full neighbor list
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 2, 10, random().nextInt());
|
||||||
// node 0 is added by the builder constructor
|
// node 0 is added by the builder constructor
|
||||||
// builder.addGraphNode(vectors.vectorValue(0));
|
// builder.addGraphNode(vectors.vectorValue(0));
|
||||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||||
builder.addGraphNode(1, vectorsCopy);
|
builder.addGraphNode(1, vectorsCopy);
|
||||||
builder.addGraphNode(2, vectorsCopy);
|
builder.addGraphNode(2, vectorsCopy);
|
||||||
// now every node has tried to attach every other node as a neighbor, but
|
// now every node has tried to attach every other node as a neighbor, but
|
||||||
|
@ -609,7 +564,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testDiversityFallback() throws IOException {
|
public void testDiversityFallback() throws IOException {
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
// Some test cases can't be exercised in two dimensions;
|
// Some test cases can't be exercised in two dimensions;
|
||||||
// in particular if a new neighbor displaces an existing neighbor
|
// in particular if a new neighbor displaces an existing neighbor
|
||||||
|
@ -622,14 +576,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
{10, 0, 0},
|
{10, 0, 0},
|
||||||
{0, 4, 0}
|
{0, 4, 0}
|
||||||
};
|
};
|
||||||
MockVectorValues vectors = new MockVectorValues(values);
|
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||||
// First add nodes until everybody gets a full neighbor list
|
// First add nodes until everybody gets a full neighbor list
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
|
||||||
// node 0 is added by the builder constructor
|
// node 0 is added by the builder constructor
|
||||||
// builder.addGraphNode(vectors.vectorValue(0));
|
// builder.addGraphNode(vectors.vectorValue(0));
|
||||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||||
builder.addGraphNode(1, vectorsCopy);
|
builder.addGraphNode(1, vectorsCopy);
|
||||||
builder.addGraphNode(2, vectorsCopy);
|
builder.addGraphNode(2, vectorsCopy);
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
|
@ -647,7 +601,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testDiversity3d() throws IOException {
|
public void testDiversity3d() throws IOException {
|
||||||
vectorEncoding = randomVectorEncoding();
|
|
||||||
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
|
// test the case when a neighbor *becomes* non-diverse when a newer better neighbor arrives
|
||||||
float[][] values = {
|
float[][] values = {
|
||||||
|
@ -656,14 +609,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
{0, 0, 20},
|
{0, 0, 20},
|
||||||
{0, 9, 0}
|
{0, 9, 0}
|
||||||
};
|
};
|
||||||
MockVectorValues vectors = new MockVectorValues(values);
|
AbstractMockVectorValues<T> vectors = vectorValues(values);
|
||||||
// First add nodes until everybody gets a full neighbor list
|
// First add nodes until everybody gets a full neighbor list
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 1, 10, random().nextInt());
|
vectors, getVectorEncoding(), similarityFunction, 1, 10, random().nextInt());
|
||||||
// node 0 is added by the builder constructor
|
// node 0 is added by the builder constructor
|
||||||
// builder.addGraphNode(vectors.vectorValue(0));
|
// builder.addGraphNode(vectors.vectorValue(0));
|
||||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
RandomAccessVectorValues<T> vectorsCopy = vectors.copy();
|
||||||
builder.addGraphNode(1, vectorsCopy);
|
builder.addGraphNode(1, vectorsCopy);
|
||||||
builder.addGraphNode(2, vectorsCopy);
|
builder.addGraphNode(2, vectorsCopy);
|
||||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||||
|
@ -691,44 +644,38 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
actual);
|
actual);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public void testRandom() throws IOException {
|
public void testRandom() throws IOException {
|
||||||
int size = atLeast(100);
|
int size = atLeast(100);
|
||||||
int dim = atLeast(10);
|
int dim = atLeast(10);
|
||||||
RandomVectorValues vectors = new RandomVectorValues(size, dim, vectorEncoding, random());
|
AbstractMockVectorValues<T> vectors = vectorValues(size, dim);
|
||||||
int topK = 5;
|
int topK = 5;
|
||||||
HnswGraphBuilder<?> builder =
|
HnswGraphBuilder<T> builder =
|
||||||
HnswGraphBuilder.create(
|
HnswGraphBuilder.create(
|
||||||
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
|
vectors, getVectorEncoding(), similarityFunction, 10, 30, random().nextLong());
|
||||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||||
|
|
||||||
int totalMatches = 0;
|
int totalMatches = 0;
|
||||||
for (int i = 0; i < 100; i++) {
|
for (int i = 0; i < 100; i++) {
|
||||||
NeighborQueue actual;
|
NeighborQueue actual;
|
||||||
float[] query;
|
T query = randomVector(dim);
|
||||||
BytesRef bQuery = null;
|
|
||||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
|
||||||
query = randomVector8(random(), dim);
|
|
||||||
bQuery = toBytesRef(query);
|
|
||||||
} else {
|
|
||||||
query = randomVector(random(), dim);
|
|
||||||
}
|
|
||||||
actual =
|
actual =
|
||||||
switch (vectorEncoding) {
|
switch (getVectorEncoding()) {
|
||||||
case BYTE -> HnswGraphSearcher.search(
|
case BYTE -> HnswGraphSearcher.search(
|
||||||
bQuery,
|
(BytesRef) query,
|
||||||
100,
|
100,
|
||||||
vectors,
|
(RandomAccessVectorValues<BytesRef>) vectors,
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
Integer.MAX_VALUE);
|
Integer.MAX_VALUE);
|
||||||
case FLOAT32 -> HnswGraphSearcher.search(
|
case FLOAT32 -> HnswGraphSearcher.search(
|
||||||
query,
|
(float[]) query,
|
||||||
100,
|
100,
|
||||||
vectors,
|
(RandomAccessVectorValues<float[]>) vectors,
|
||||||
vectorEncoding,
|
getVectorEncoding(),
|
||||||
similarityFunction,
|
similarityFunction,
|
||||||
hnsw,
|
hnsw,
|
||||||
acceptOrds,
|
acceptOrds,
|
||||||
|
@ -741,10 +688,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
NeighborQueue expected = new NeighborQueue(topK, false);
|
NeighborQueue expected = new NeighborQueue(topK, false);
|
||||||
for (int j = 0; j < size; j++) {
|
for (int j = 0; j < size; j++) {
|
||||||
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) {
|
||||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
if (getVectorEncoding() == VectorEncoding.BYTE) {
|
||||||
expected.add(j, similarityFunction.compare(bQuery, vectors.binaryValue(j)));
|
assert query instanceof BytesRef;
|
||||||
|
expected.add(
|
||||||
|
j, similarityFunction.compare((BytesRef) query, (BytesRef) vectors.vectorValue(j)));
|
||||||
} else {
|
} else {
|
||||||
expected.add(j, similarityFunction.compare(query, vectors.vectorValue(j)));
|
assert query instanceof float[];
|
||||||
|
expected.add(
|
||||||
|
j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j)));
|
||||||
}
|
}
|
||||||
if (expected.size() > topK) {
|
if (expected.size() > topK) {
|
||||||
expected.pop();
|
expected.pop();
|
||||||
|
@ -778,17 +729,16 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||||
static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues {
|
static class CircularVectorValues extends VectorValues
|
||||||
|
implements RandomAccessVectorValues<float[]> {
|
||||||
private final int size;
|
private final int size;
|
||||||
private final float[] value;
|
private final float[] value;
|
||||||
private final BytesRef binaryValue;
|
|
||||||
|
|
||||||
int doc = -1;
|
int doc = -1;
|
||||||
|
|
||||||
CircularVectorValues(int size) {
|
CircularVectorValues(int size) {
|
||||||
this.size = size;
|
this.size = size;
|
||||||
value = new float[2];
|
value = new float[2];
|
||||||
binaryValue = new BytesRef(new byte[2]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -835,14 +785,70 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
public float[] vectorValue(int ord) {
|
public float[] vectorValue(int ord) {
|
||||||
return unitVector2d(ord / (double) size, value);
|
return unitVector2d(ord / (double) size, value);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||||
|
static class CircularByteVectorValues extends ByteVectorValues
|
||||||
|
implements RandomAccessVectorValues<BytesRef> {
|
||||||
|
private final int size;
|
||||||
|
private final float[] value;
|
||||||
|
private final BytesRef bValue;
|
||||||
|
|
||||||
|
int doc = -1;
|
||||||
|
|
||||||
|
CircularByteVectorValues(int size) {
|
||||||
|
this.size = size;
|
||||||
|
value = new float[2];
|
||||||
|
bValue = new BytesRef(new byte[2]);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public BytesRef binaryValue(int ord) {
|
public CircularByteVectorValues copy() {
|
||||||
float[] vectorValue = vectorValue(ord);
|
return new CircularByteVectorValues(size);
|
||||||
for (int i = 0; i < vectorValue.length; i++) {
|
|
||||||
binaryValue.bytes[i] = (byte) (vectorValue[i] * 127);
|
|
||||||
}
|
}
|
||||||
return binaryValue;
|
|
||||||
|
@Override
|
||||||
|
public int dimension() {
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int size() {
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() {
|
||||||
|
return vectorValue(doc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int docID() {
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int nextDoc() {
|
||||||
|
return advance(doc + 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int advance(int target) {
|
||||||
|
if (target >= 0 && target < size) {
|
||||||
|
doc = target;
|
||||||
|
} else {
|
||||||
|
doc = NO_MORE_DOCS;
|
||||||
|
}
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue(int ord) {
|
||||||
|
unitVector2d(ord / (double) size, value);
|
||||||
|
for (int i = 0; i < value.length; i++) {
|
||||||
|
bValue.bytes[i] = (byte) (value[i] * 127);
|
||||||
|
}
|
||||||
|
return bValue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -864,7 +870,8 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
return neighbors;
|
return neighbors;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void assertVectorsEqual(VectorValues u, VectorValues v) throws IOException {
|
void assertVectorsEqual(AbstractMockVectorValues<T> u, AbstractMockVectorValues<T> v)
|
||||||
|
throws IOException {
|
||||||
int uDoc, vDoc;
|
int uDoc, vDoc;
|
||||||
while (true) {
|
while (true) {
|
||||||
uDoc = u.nextDoc();
|
uDoc = u.nextDoc();
|
||||||
|
@ -873,49 +880,40 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
if (uDoc == NO_MORE_DOCS) {
|
if (uDoc == NO_MORE_DOCS) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
float delta = vectorEncoding == VectorEncoding.BYTE ? 1 : 1e-4f;
|
switch (getVectorEncoding()) {
|
||||||
|
case BYTE:
|
||||||
assertArrayEquals(
|
assertArrayEquals(
|
||||||
"vectors do not match for doc=" + uDoc, u.vectorValue(), v.vectorValue(), delta);
|
"vectors do not match for doc=" + uDoc,
|
||||||
|
((BytesRef) u.vectorValue()).bytes,
|
||||||
|
((BytesRef) v.vectorValue()).bytes);
|
||||||
|
break;
|
||||||
|
case FLOAT32:
|
||||||
|
assertArrayEquals(
|
||||||
|
"vectors do not match for doc=" + uDoc,
|
||||||
|
(float[]) u.vectorValue(),
|
||||||
|
(float[]) v.vectorValue(),
|
||||||
|
1e-4f);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Produces random vectors and caches them for random-access. */
|
static float[][] createRandomFloatVectors(int size, int dimension, Random random) {
|
||||||
static class RandomVectorValues extends MockVectorValues {
|
|
||||||
|
|
||||||
RandomVectorValues(int size, int dimension, Random random) {
|
|
||||||
super(createRandomVectors(size, dimension, null, random));
|
|
||||||
}
|
|
||||||
|
|
||||||
RandomVectorValues(int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
|
||||||
super(createRandomVectors(size, dimension, vectorEncoding, random));
|
|
||||||
}
|
|
||||||
|
|
||||||
RandomVectorValues(RandomVectorValues other) {
|
|
||||||
super(other.values);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RandomVectorValues copy() {
|
|
||||||
return new RandomVectorValues(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static float[][] createRandomVectors(
|
|
||||||
int size, int dimension, VectorEncoding vectorEncoding, Random random) {
|
|
||||||
float[][] vectors = new float[size][];
|
float[][] vectors = new float[size][];
|
||||||
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||||
vectors[offset] = randomVector(random, dimension);
|
vectors[offset] = randomVector(random, dimension);
|
||||||
}
|
}
|
||||||
if (vectorEncoding == VectorEncoding.BYTE) {
|
|
||||||
for (float[] vector : vectors) {
|
|
||||||
if (vector != null) {
|
|
||||||
for (int i = 0; i < vector.length; i++) {
|
|
||||||
vector[i] = (byte) (127 * vector[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return vectors;
|
return vectors;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static byte[][] createRandomByteVectors(int size, int dimension, Random random) {
|
||||||
|
byte[][] vectors = new byte[size][];
|
||||||
|
for (int offset = 0; offset < size; offset += random.nextInt(3) + 1) {
|
||||||
|
vectors[offset] = randomVector8(random, dimension);
|
||||||
|
}
|
||||||
|
return vectors;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -937,7 +935,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
return bits;
|
return bits;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float[] randomVector(Random random, int dim) {
|
static float[] randomVector(Random random, int dim) {
|
||||||
float[] vec = new float[dim];
|
float[] vec = new float[dim];
|
||||||
for (int i = 0; i < dim; i++) {
|
for (int i = 0; i < dim; i++) {
|
||||||
vec[i] = random.nextFloat();
|
vec[i] = random.nextFloat();
|
||||||
|
@ -949,11 +947,12 @@ public class TestHnswGraph extends LuceneTestCase {
|
||||||
return vec;
|
return vec;
|
||||||
}
|
}
|
||||||
|
|
||||||
private static float[] randomVector8(Random random, int dim) {
|
static byte[] randomVector8(Random random, int dim) {
|
||||||
float[] fvec = randomVector(random, dim);
|
float[] fvec = randomVector(random, dim);
|
||||||
|
byte[] bvec = new byte[dim];
|
||||||
for (int i = 0; i < dim; i++) {
|
for (int i = 0; i < dim; i++) {
|
||||||
fvec[i] *= 127;
|
bvec[i] = (byte) (fvec[i] * 127);
|
||||||
}
|
}
|
||||||
return fvec;
|
return bvec;
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -45,6 +45,7 @@ import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
||||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.FieldType;
|
import org.apache.lucene.document.FieldType;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.document.StoredField;
|
import org.apache.lucene.document.StoredField;
|
||||||
import org.apache.lucene.index.CodecReader;
|
import org.apache.lucene.index.CodecReader;
|
||||||
|
@ -704,7 +705,11 @@ public class KnnGraphTester {
|
||||||
iwc.setUseCompoundFile(false);
|
iwc.setUseCompoundFile(false);
|
||||||
// iwc.setMaxBufferedDocs(10000);
|
// iwc.setMaxBufferedDocs(10000);
|
||||||
|
|
||||||
FieldType fieldType = KnnVectorField.createFieldType(dim, vectorEncoding, similarityFunction);
|
FieldType fieldType =
|
||||||
|
switch (vectorEncoding) {
|
||||||
|
case BYTE -> KnnByteVectorField.createFieldType(dim, similarityFunction);
|
||||||
|
case FLOAT32 -> KnnVectorField.createFieldType(dim, similarityFunction);
|
||||||
|
};
|
||||||
if (quiet == false) {
|
if (quiet == false) {
|
||||||
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
iwc.setInfoStream(new PrintStreamInfoStream(System.out));
|
||||||
System.out.println("creating index in " + indexPath);
|
System.out.println("creating index in " + indexPath);
|
||||||
|
@ -718,7 +723,7 @@ public class KnnGraphTester {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
switch (vectorEncoding) {
|
switch (vectorEncoding) {
|
||||||
case BYTE -> doc.add(
|
case BYTE -> doc.add(
|
||||||
new KnnVectorField(
|
new KnnByteVectorField(
|
||||||
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
|
KNN_FIELD, ((VectorReaderByte) vectorReader).nextBytes(), fieldType));
|
||||||
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
|
case FLOAT32 -> doc.add(new KnnVectorField(KNN_FIELD, vectorReader.next(), fieldType));
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,70 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
|
||||||
|
class MockByteVectorValues extends AbstractMockVectorValues<BytesRef> {
|
||||||
|
private final byte[] scratch;
|
||||||
|
|
||||||
|
static MockByteVectorValues fromValues(byte[][] byteValues) {
|
||||||
|
int dimension = byteValues[0].length;
|
||||||
|
BytesRef[] values = new BytesRef[byteValues.length];
|
||||||
|
for (int i = 0; i < byteValues.length; i++) {
|
||||||
|
values[i] = byteValues[i] == null ? null : new BytesRef(byteValues[i]);
|
||||||
|
}
|
||||||
|
BytesRef[] denseValues = new BytesRef[values.length];
|
||||||
|
int count = 0;
|
||||||
|
for (int i = 0; i < byteValues.length; i++) {
|
||||||
|
if (values[i] != null) {
|
||||||
|
denseValues[count++] = values[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new MockByteVectorValues(values, dimension, denseValues, count);
|
||||||
|
}
|
||||||
|
|
||||||
|
MockByteVectorValues(BytesRef[] values, int dimension, BytesRef[] denseValues, int numVectors) {
|
||||||
|
super(values, dimension, denseValues, numVectors);
|
||||||
|
scratch = new byte[dimension];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public MockByteVectorValues copy() {
|
||||||
|
return new MockByteVectorValues(
|
||||||
|
ArrayUtil.copyOfSubArray(values, 0, values.length),
|
||||||
|
dimension,
|
||||||
|
ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
|
||||||
|
numVectors);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public BytesRef vectorValue() {
|
||||||
|
if (LuceneTestCase.random().nextBoolean()) {
|
||||||
|
return values[pos];
|
||||||
|
} else {
|
||||||
|
// Sometimes use the same scratch array repeatedly, mimicking what the codec will do.
|
||||||
|
// This should help us catch cases of aliasing where the same ByteVectorValues source is used
|
||||||
|
// twice in a
|
||||||
|
// single computation.
|
||||||
|
System.arraycopy(values[pos].bytes, values[pos].offset, scratch, 0, dimension);
|
||||||
|
return new BytesRef(scratch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,52 +17,37 @@
|
||||||
|
|
||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import org.apache.lucene.index.VectorValues;
|
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
import org.apache.lucene.util.BytesRef;
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
|
||||||
class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
|
class MockVectorValues extends AbstractMockVectorValues<float[]> {
|
||||||
private final float[] scratch;
|
private final float[] scratch;
|
||||||
|
|
||||||
protected final int dimension;
|
static MockVectorValues fromValues(float[][] values) {
|
||||||
protected final float[][] denseValues;
|
int dimension = values[0].length;
|
||||||
protected final float[][] values;
|
|
||||||
private final int numVectors;
|
|
||||||
private final BytesRef binaryValue;
|
|
||||||
|
|
||||||
private int pos = -1;
|
|
||||||
|
|
||||||
MockVectorValues(float[][] values) {
|
|
||||||
this.dimension = values[0].length;
|
|
||||||
this.values = values;
|
|
||||||
int maxDoc = values.length;
|
int maxDoc = values.length;
|
||||||
denseValues = new float[maxDoc][];
|
float[][] denseValues = new float[maxDoc][];
|
||||||
int count = 0;
|
int count = 0;
|
||||||
for (int i = 0; i < maxDoc; i++) {
|
for (int i = 0; i < maxDoc; i++) {
|
||||||
if (values[i] != null) {
|
if (values[i] != null) {
|
||||||
denseValues[count++] = values[i];
|
denseValues[count++] = values[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
numVectors = count;
|
return new MockVectorValues(values, dimension, denseValues, count);
|
||||||
scratch = new float[dimension];
|
}
|
||||||
// used by tests that build a graph from bytes rather than floats
|
|
||||||
binaryValue = new BytesRef(dimension);
|
MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) {
|
||||||
binaryValue.length = dimension;
|
super(values, dimension, denseValues, numVectors);
|
||||||
|
this.scratch = new float[dimension];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public MockVectorValues copy() {
|
public MockVectorValues copy() {
|
||||||
return new MockVectorValues(values);
|
return new MockVectorValues(
|
||||||
}
|
ArrayUtil.copyOfSubArray(values, 0, values.length),
|
||||||
|
dimension,
|
||||||
@Override
|
ArrayUtil.copyOfSubArray(denseValues, 0, denseValues.length),
|
||||||
public int size() {
|
numVectors);
|
||||||
return numVectors;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int dimension() {
|
|
||||||
return dimension;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -83,42 +68,4 @@ class MockVectorValues extends VectorValues implements RandomAccessVectorValues
|
||||||
public float[] vectorValue(int targetOrd) {
|
public float[] vectorValue(int targetOrd) {
|
||||||
return denseValues[targetOrd];
|
return denseValues[targetOrd];
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public BytesRef binaryValue(int targetOrd) {
|
|
||||||
float[] value = vectorValue(targetOrd);
|
|
||||||
for (int i = 0; i < value.length; i++) {
|
|
||||||
binaryValue.bytes[i] = (byte) value[i];
|
|
||||||
}
|
|
||||||
return binaryValue;
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean seek(int target) {
|
|
||||||
if (target >= 0 && target < values.length && values[target] != null) {
|
|
||||||
pos = target;
|
|
||||||
return true;
|
|
||||||
} else {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int docID() {
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int nextDoc() {
|
|
||||||
return advance(pos + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int advance(int target) {
|
|
||||||
while (++pos < values.length) {
|
|
||||||
if (seek(pos)) {
|
|
||||||
return pos;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return NO_MORE_DOCS;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,118 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.search.KnnByteVectorQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
import org.apache.lucene.util.BytesRef;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
/** Tests HNSW KNN graphs */
|
||||||
|
public class TestHnswByteVectorGraph extends HnswGraphTestCase<BytesRef> {
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
VectorEncoding getVectorEncoding() {
|
||||||
|
return VectorEncoding.BYTE;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Query knnQuery(String field, BytesRef vector, int k) {
|
||||||
|
return new KnnByteVectorQuery(field, vector, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
BytesRef randomVector(int dim) {
|
||||||
|
return new BytesRef(randomVector8(random(), dim));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<BytesRef> vectorValues(int size, int dimension) {
|
||||||
|
return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random()));
|
||||||
|
}
|
||||||
|
|
||||||
|
static boolean fitsInByte(float v) {
|
||||||
|
return v <= 127 && v >= -128 && v % 1 == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<BytesRef> vectorValues(float[][] values) {
|
||||||
|
byte[][] bValues = new byte[values.length][];
|
||||||
|
// The case when all floats fit within a byte already.
|
||||||
|
boolean scaleSimple = fitsInByte(values[0][0]);
|
||||||
|
for (int i = 0; i < values.length; i++) {
|
||||||
|
bValues[i] = new byte[values[i].length];
|
||||||
|
for (int j = 0; j < values[i].length; j++) {
|
||||||
|
final float v;
|
||||||
|
if (scaleSimple) {
|
||||||
|
assert fitsInByte(values[i][j]);
|
||||||
|
v = values[i][j];
|
||||||
|
} else {
|
||||||
|
v = values[i][j] * 127;
|
||||||
|
}
|
||||||
|
bValues[i][j] = (byte) v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return MockByteVectorValues.fromValues(bValues);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<BytesRef> vectorValues(LeafReader reader, String fieldName)
|
||||||
|
throws IOException {
|
||||||
|
ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName);
|
||||||
|
byte[][] vectors = new byte[reader.maxDoc()][];
|
||||||
|
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
vectors[vectorValues.docID()] =
|
||||||
|
ArrayUtil.copyOfSubArray(
|
||||||
|
vectorValues.vectorValue().bytes,
|
||||||
|
vectorValues.vectorValue().offset,
|
||||||
|
vectorValues.vectorValue().offset + vectorValues.vectorValue().length);
|
||||||
|
}
|
||||||
|
return MockByteVectorValues.fromValues(vectors);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Field knnVectorField(String name, BytesRef vector, VectorSimilarityFunction similarityFunction) {
|
||||||
|
return new KnnByteVectorField(name, vector, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
RandomAccessVectorValues<BytesRef> circularVectorValues(int nDoc) {
|
||||||
|
return new CircularByteVectorValues(nDoc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
BytesRef getTargetVector() {
|
||||||
|
return new BytesRef(new byte[] {1, 0});
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,133 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
import com.carrotsearch.randomizedtesting.RandomizedTest;
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.document.Field;
|
||||||
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
|
import org.apache.lucene.index.LeafReader;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.index.VectorValues;
|
||||||
|
import org.apache.lucene.search.KnnVectorQuery;
|
||||||
|
import org.apache.lucene.search.Query;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
import org.apache.lucene.util.FixedBitSet;
|
||||||
|
import org.junit.Before;
|
||||||
|
|
||||||
|
/** Tests HNSW KNN graphs */
|
||||||
|
public class TestHnswFloatVectorGraph extends HnswGraphTestCase<float[]> {
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
VectorEncoding getVectorEncoding() {
|
||||||
|
return VectorEncoding.FLOAT32;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Query knnQuery(String field, float[] vector, int k) {
|
||||||
|
return new KnnVectorQuery(field, vector, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
float[] randomVector(int dim) {
|
||||||
|
return randomVector(random(), dim);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<float[]> vectorValues(int size, int dimension) {
|
||||||
|
return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random()));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<float[]> vectorValues(float[][] values) {
|
||||||
|
return MockVectorValues.fromValues(values);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
AbstractMockVectorValues<float[]> vectorValues(LeafReader reader, String fieldName)
|
||||||
|
throws IOException {
|
||||||
|
VectorValues vectorValues = reader.getVectorValues(fieldName);
|
||||||
|
float[][] vectors = new float[reader.maxDoc()][];
|
||||||
|
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
vectors[vectorValues.docID()] =
|
||||||
|
ArrayUtil.copyOfSubArray(
|
||||||
|
vectorValues.vectorValue(), 0, vectorValues.vectorValue().length);
|
||||||
|
}
|
||||||
|
return MockVectorValues.fromValues(vectors);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
Field knnVectorField(String name, float[] vector, VectorSimilarityFunction similarityFunction) {
|
||||||
|
return new KnnVectorField(name, vector, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
RandomAccessVectorValues<float[]> circularVectorValues(int nDoc) {
|
||||||
|
return new CircularVectorValues(nDoc);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
float[] getTargetVector() {
|
||||||
|
return new float[] {1f, 0f};
|
||||||
|
}
|
||||||
|
|
||||||
|
public void testSearchWithSkewedAcceptOrds() throws IOException {
|
||||||
|
int nDoc = 1000;
|
||||||
|
similarityFunction = VectorSimilarityFunction.EUCLIDEAN;
|
||||||
|
RandomAccessVectorValues<float[]> vectors = circularVectorValues(nDoc);
|
||||||
|
HnswGraphBuilder<float[]> builder =
|
||||||
|
HnswGraphBuilder.create(
|
||||||
|
vectors, getVectorEncoding(), similarityFunction, 16, 100, random().nextInt());
|
||||||
|
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||||
|
|
||||||
|
// Skip over half of the documents that are closest to the query vector
|
||||||
|
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||||
|
for (int i = 500; i < nDoc; i++) {
|
||||||
|
acceptOrds.set(i);
|
||||||
|
}
|
||||||
|
NeighborQueue nn =
|
||||||
|
HnswGraphSearcher.search(
|
||||||
|
getTargetVector(),
|
||||||
|
10,
|
||||||
|
vectors.copy(),
|
||||||
|
getVectorEncoding(),
|
||||||
|
similarityFunction,
|
||||||
|
hnsw,
|
||||||
|
acceptOrds,
|
||||||
|
Integer.MAX_VALUE);
|
||||||
|
|
||||||
|
int[] nodes = nn.nodes();
|
||||||
|
assertEquals("Number of found results is not equal to [10].", 10, nodes.length);
|
||||||
|
int sum = 0;
|
||||||
|
for (int node : nodes) {
|
||||||
|
assertTrue("the results include a deleted document: " + node, acceptOrds.get(node));
|
||||||
|
sum += node;
|
||||||
|
}
|
||||||
|
// We still expect to get reasonable recall. The lowest non-skipped docIds
|
||||||
|
// are closest to the query vector: sum(500,509) = 5045
|
||||||
|
assertTrue("sum(result docs)=" + sum, sum < 5100);
|
||||||
|
}
|
||||||
|
}
|
|
@ -20,6 +20,7 @@ import java.io.IOException;
|
||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Iterator;
|
import java.util.Iterator;
|
||||||
import org.apache.lucene.index.BinaryDocValues;
|
import org.apache.lucene.index.BinaryDocValues;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.DocValuesType;
|
import org.apache.lucene.index.DocValuesType;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
@ -165,6 +166,11 @@ public class TermVectorLeafReader extends LeafReader {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String fieldName) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
|
|
@ -1395,6 +1395,11 @@ public class MemoryIndex {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String fieldName) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
|
@ -113,7 +114,9 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
@Override
|
@Override
|
||||||
public VectorValues getVectorValues(String field) throws IOException {
|
public VectorValues getVectorValues(String field) throws IOException {
|
||||||
FieldInfo fi = fis.fieldInfo(field);
|
FieldInfo fi = fis.fieldInfo(field);
|
||||||
assert fi != null && fi.getVectorDimension() > 0;
|
assert fi != null
|
||||||
|
&& fi.getVectorDimension() > 0
|
||||||
|
&& fi.getVectorEncoding() == VectorEncoding.FLOAT32;
|
||||||
VectorValues values = delegate.getVectorValues(field);
|
VectorValues values = delegate.getVectorValues(field);
|
||||||
assert values != null;
|
assert values != null;
|
||||||
assert values.docID() == -1;
|
assert values.docID() == -1;
|
||||||
|
@ -122,6 +125,20 @@ public class AssertingKnnVectorsFormat extends KnnVectorsFormat {
|
||||||
return values;
|
return values;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
FieldInfo fi = fis.fieldInfo(field);
|
||||||
|
assert fi != null
|
||||||
|
&& fi.getVectorDimension() > 0
|
||||||
|
&& fi.getVectorEncoding() == VectorEncoding.BYTE;
|
||||||
|
ByteVectorValues values = delegate.getByteVectorValues(field);
|
||||||
|
assert values != null;
|
||||||
|
assert values.docID() == -1;
|
||||||
|
assert values.size() >= 0;
|
||||||
|
assert values.dimension() > 0;
|
||||||
|
return values;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
public TopDocs search(String field, float[] target, int k, Bits acceptDocs, int visitedLimit)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
|
|
@ -28,10 +28,12 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
import org.apache.lucene.document.FieldType;
|
import org.apache.lucene.document.FieldType;
|
||||||
|
import org.apache.lucene.document.KnnByteVectorField;
|
||||||
import org.apache.lucene.document.KnnVectorField;
|
import org.apache.lucene.document.KnnVectorField;
|
||||||
import org.apache.lucene.document.NumericDocValuesField;
|
import org.apache.lucene.document.NumericDocValuesField;
|
||||||
import org.apache.lucene.document.StoredField;
|
import org.apache.lucene.document.StoredField;
|
||||||
import org.apache.lucene.document.StringField;
|
import org.apache.lucene.document.StringField;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CheckIndex;
|
import org.apache.lucene.index.CheckIndex;
|
||||||
import org.apache.lucene.index.CodecReader;
|
import org.apache.lucene.index.CodecReader;
|
||||||
import org.apache.lucene.index.DirectoryReader;
|
import org.apache.lucene.index.DirectoryReader;
|
||||||
|
@ -79,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
@Override
|
@Override
|
||||||
protected void addRandomFields(Document doc) {
|
protected void addRandomFields(Document doc) {
|
||||||
switch (vectorEncoding) {
|
switch (vectorEncoding) {
|
||||||
case BYTE -> doc.add(new KnnVectorField("v2", randomVector8(30), similarityFunction));
|
case BYTE -> doc.add(
|
||||||
|
new KnnByteVectorField("v2", new BytesRef(randomVector8(30)), similarityFunction));
|
||||||
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
|
case FLOAT32 -> doc.add(new KnnVectorField("v2", randomVector(30), similarityFunction));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -628,9 +631,11 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
if (random().nextInt(100) == 17) {
|
if (random().nextInt(100) == 17) {
|
||||||
switch (fieldVectorEncodings[field]) {
|
switch (fieldVectorEncodings[field]) {
|
||||||
case BYTE -> {
|
case BYTE -> {
|
||||||
BytesRef b = randomVector8(fieldDims[field]);
|
byte[] b = randomVector8(fieldDims[field]);
|
||||||
doc.add(new KnnVectorField(fieldName, b, fieldSimilarityFunctions[field]));
|
doc.add(
|
||||||
fieldTotals[field] += b.bytes[b.offset];
|
new KnnByteVectorField(
|
||||||
|
fieldName, new BytesRef(b), fieldSimilarityFunctions[field]));
|
||||||
|
fieldTotals[field] += b[0];
|
||||||
}
|
}
|
||||||
case FLOAT32 -> {
|
case FLOAT32 -> {
|
||||||
float[] v = randomVector(fieldDims[field]);
|
float[] v = randomVector(fieldDims[field]);
|
||||||
|
@ -648,12 +653,27 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
int docCount = 0;
|
int docCount = 0;
|
||||||
double checksum = 0;
|
double checksum = 0;
|
||||||
String fieldName = "int" + field;
|
String fieldName = "int" + field;
|
||||||
|
switch (fieldVectorEncodings[field]) {
|
||||||
|
case BYTE -> {
|
||||||
for (LeafReaderContext ctx : r.leaves()) {
|
for (LeafReaderContext ctx : r.leaves()) {
|
||||||
VectorValues vectors = ctx.reader().getVectorValues(fieldName);
|
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||||
if (vectors != null) {
|
if (byteVectorValues != null) {
|
||||||
docCount += vectors.size();
|
docCount += byteVectorValues.size();
|
||||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
checksum += vectors.vectorValue()[0];
|
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
for (LeafReaderContext ctx : r.leaves()) {
|
||||||
|
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
|
||||||
|
if (vectorValues != null) {
|
||||||
|
docCount += vectorValues.size();
|
||||||
|
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
checksum += vectorValues.vectorValue()[0];
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -755,15 +775,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
LeafReader leaf = getOnlyLeafReader(reader);
|
LeafReader leaf = getOnlyLeafReader(reader);
|
||||||
|
|
||||||
StoredFields storedFields = leaf.storedFields();
|
StoredFields storedFields = leaf.storedFields();
|
||||||
VectorValues vectorValues = leaf.getVectorValues(fieldName);
|
ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName);
|
||||||
assertEquals(2, vectorValues.dimension());
|
assertEquals(2, vectorValues.dimension());
|
||||||
assertEquals(3, vectorValues.size());
|
assertEquals(3, vectorValues.size());
|
||||||
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||||
assertEquals(-1f, vectorValues.vectorValue()[0], 0);
|
assertEquals(-1, vectorValues.vectorValue().bytes[0], 0);
|
||||||
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||||
assertEquals(1, vectorValues.vectorValue()[0], 0);
|
assertEquals(1, vectorValues.vectorValue().bytes[0], 0);
|
||||||
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id"));
|
||||||
assertEquals(0, vectorValues.vectorValue()[0], 0);
|
assertEquals(0, vectorValues.vectorValue().bytes[0], 0);
|
||||||
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
assertEquals(NO_MORE_DOCS, vectorValues.nextDoc());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -915,7 +935,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
if (random().nextInt(7) != 3) {
|
if (random().nextInt(7) != 3) {
|
||||||
// usually index a vector value for a doc
|
// usually index a vector value for a doc
|
||||||
values[i] = randomVector8(dimension);
|
values[i] = new BytesRef(randomVector8(dimension));
|
||||||
++numValues;
|
++numValues;
|
||||||
}
|
}
|
||||||
if (random().nextBoolean() && values[i] != null) {
|
if (random().nextBoolean() && values[i] != null) {
|
||||||
|
@ -943,7 +963,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
try (IndexReader reader = DirectoryReader.open(iw)) {
|
try (IndexReader reader = DirectoryReader.open(iw)) {
|
||||||
int valueCount = 0, totalSize = 0;
|
int valueCount = 0, totalSize = 0;
|
||||||
for (LeafReaderContext ctx : reader.leaves()) {
|
for (LeafReaderContext ctx : reader.leaves()) {
|
||||||
VectorValues vectorValues = ctx.reader().getVectorValues(fieldName);
|
ByteVectorValues vectorValues = ctx.reader().getByteVectorValues(fieldName);
|
||||||
if (vectorValues == null) {
|
if (vectorValues == null) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
@ -951,7 +971,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
StoredFields storedFields = ctx.reader().storedFields();
|
StoredFields storedFields = ctx.reader().storedFields();
|
||||||
int docId;
|
int docId;
|
||||||
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||||
BytesRef v = vectorValues.binaryValue();
|
BytesRef v = vectorValues.vectorValue();
|
||||||
assertEquals(dimension, v.length);
|
assertEquals(dimension, v.length);
|
||||||
String idString = storedFields.document(docId).getField("id").stringValue();
|
String idString = storedFields.document(docId).getField("id").stringValue();
|
||||||
int id = Integer.parseInt(idString);
|
int id = Integer.parseInt(idString);
|
||||||
|
@ -1141,7 +1161,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
throws IOException {
|
throws IOException {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
if (vector != null) {
|
if (vector != null) {
|
||||||
doc.add(new KnnVectorField(field, vector, similarityFunction));
|
doc.add(new KnnByteVectorField(field, vector, similarityFunction));
|
||||||
}
|
}
|
||||||
doc.add(new NumericDocValuesField("sortkey", sortKey));
|
doc.add(new NumericDocValuesField("sortkey", sortKey));
|
||||||
String idString = Integer.toString(id);
|
String idString = Integer.toString(id);
|
||||||
|
@ -1183,13 +1203,13 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
private BytesRef randomVector8(int dim) {
|
private byte[] randomVector8(int dim) {
|
||||||
float[] v = randomVector(dim);
|
float[] v = randomVector(dim);
|
||||||
byte[] b = new byte[dim];
|
byte[] b = new byte[dim];
|
||||||
for (int i = 0; i < dim; i++) {
|
for (int i = 0; i < dim; i++) {
|
||||||
b[i] = (byte) (v[i] * 127);
|
b[i] = (byte) (v[i] * 127);
|
||||||
}
|
}
|
||||||
return new BytesRef(b);
|
return b;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void testCheckIndexIncludesVectors() throws Exception {
|
public void testCheckIndexIncludesVectors() throws Exception {
|
||||||
|
@ -1297,9 +1317,9 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
if (random().nextInt(4) == 3) {
|
if (random().nextInt(4) == 3) {
|
||||||
switch (vectorEncoding) {
|
switch (vectorEncoding) {
|
||||||
case BYTE -> {
|
case BYTE -> {
|
||||||
BytesRef b = randomVector8(dim);
|
byte[] b = randomVector8(dim);
|
||||||
fieldValuesCheckSum += b.bytes[b.offset];
|
fieldValuesCheckSum += b[0];
|
||||||
doc.add(new KnnVectorField("knn_vector", b, similarityFunction));
|
doc.add(new KnnByteVectorField("knn_vector", new BytesRef(b), similarityFunction));
|
||||||
}
|
}
|
||||||
case FLOAT32 -> {
|
case FLOAT32 -> {
|
||||||
float[] v = randomVector(dim);
|
float[] v = randomVector(dim);
|
||||||
|
@ -1321,18 +1341,36 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
double checksum = 0;
|
double checksum = 0;
|
||||||
int docCount = 0;
|
int docCount = 0;
|
||||||
long sumDocIds = 0;
|
long sumDocIds = 0;
|
||||||
|
switch (vectorEncoding) {
|
||||||
|
case BYTE -> {
|
||||||
for (LeafReaderContext ctx : r.leaves()) {
|
for (LeafReaderContext ctx : r.leaves()) {
|
||||||
VectorValues vectors = ctx.reader().getVectorValues("knn_vector");
|
ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues("knn_vector");
|
||||||
if (vectors != null) {
|
if (byteVectorValues != null) {
|
||||||
|
docCount += byteVectorValues.size();
|
||||||
StoredFields storedFields = ctx.reader().storedFields();
|
StoredFields storedFields = ctx.reader().storedFields();
|
||||||
docCount += vectors.size();
|
while (byteVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
while (vectors.nextDoc() != NO_MORE_DOCS) {
|
checksum += byteVectorValues.vectorValue().bytes[0];
|
||||||
checksum += vectors.vectorValue()[0];
|
Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id"));
|
||||||
Document doc = storedFields.document(vectors.docID(), Set.of("id"));
|
|
||||||
sumDocIds += Integer.parseInt(doc.get("id"));
|
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
case FLOAT32 -> {
|
||||||
|
for (LeafReaderContext ctx : r.leaves()) {
|
||||||
|
VectorValues vectorValues = ctx.reader().getVectorValues("knn_vector");
|
||||||
|
if (vectorValues != null) {
|
||||||
|
docCount += vectorValues.size();
|
||||||
|
StoredFields storedFields = ctx.reader().storedFields();
|
||||||
|
while (vectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||||
|
checksum += vectorValues.vectorValue()[0];
|
||||||
|
Document doc = storedFields.document(vectorValues.docID(), Set.of("id"));
|
||||||
|
sumDocIds += Integer.parseInt(doc.get("id"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
assertEquals(
|
assertEquals(
|
||||||
fieldValuesCheckSum,
|
fieldValuesCheckSum,
|
||||||
checksum,
|
checksum,
|
||||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.lucene.codecs.NormsProducer;
|
||||||
import org.apache.lucene.codecs.StoredFieldsReader;
|
import org.apache.lucene.codecs.StoredFieldsReader;
|
||||||
import org.apache.lucene.codecs.TermVectorsReader;
|
import org.apache.lucene.codecs.TermVectorsReader;
|
||||||
import org.apache.lucene.index.BinaryDocValues;
|
import org.apache.lucene.index.BinaryDocValues;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CodecReader;
|
import org.apache.lucene.index.CodecReader;
|
||||||
import org.apache.lucene.index.DocValuesType;
|
import org.apache.lucene.index.DocValuesType;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
@ -222,6 +223,11 @@ class MergeReaderWrapper extends LeafReader {
|
||||||
return in.getVectorValues(fieldName);
|
return in.getVectorValues(fieldName);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String fieldName) throws IOException {
|
||||||
|
return in.getByteVectorValues(fieldName);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import java.io.IOException;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
import org.apache.lucene.index.BinaryDocValues;
|
import org.apache.lucene.index.BinaryDocValues;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.FieldInfos;
|
import org.apache.lucene.index.FieldInfos;
|
||||||
import org.apache.lucene.index.IndexReader;
|
import org.apache.lucene.index.IndexReader;
|
||||||
import org.apache.lucene.index.LeafMetaData;
|
import org.apache.lucene.index.LeafMetaData;
|
||||||
|
@ -229,6 +230,11 @@ public class QueryUtils {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public TopDocs searchNearestVectors(
|
public TopDocs searchNearestVectors(
|
||||||
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
String field, float[] target, int k, Bits acceptDocs, int visitedLimit) {
|
||||||
|
|
Loading…
Reference in New Issue