mirror of https://github.com/apache/lucene.git
LUCENE-9583: Remove RandomAccessVectorValuesProducer (#1071)
This change folds the `RandomAccessVectorValuesProducer` interface into `RandomAccessVectorValues`. This reduces the number of interfaces and clarifies the cloning/ copying behavior. This is a small simplification related to LUCENE-9583, but does not address the main issue.
This commit is contained in:
parent
0914b537db
commit
8308688d78
|
@ -22,7 +22,6 @@ import java.util.Locale;
|
|||
import java.util.Objects;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.hnsw.NeighborQueue;
|
||||
|
@ -72,14 +71,14 @@ public final class Lucene90HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene90HnswGraphBuilder(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
long seed)
|
||||
throws IOException {
|
||||
vectorValues = vectors.randomAccess();
|
||||
buildVectors = vectors.randomAccess();
|
||||
vectorValues = vectors.copy();
|
||||
buildVectors = vectors.copy();
|
||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||
if (maxConn <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
|
|
|
@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfo;
|
|||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -381,8 +380,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
|
||||
final int dimension;
|
||||
final int[] ordToDoc;
|
||||
|
@ -468,7 +466,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
public RandomAccessVectorValues copy() {
|
||||
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
|
||||
}
|
||||
|
||||
|
|
|
@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfo;
|
|||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -422,8 +421,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
static class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
|
||||
private final int dimension;
|
||||
private final int size;
|
||||
|
@ -516,7 +514,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
public RandomAccessVectorValues copy() {
|
||||
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.io.IOException;
|
|||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
|
@ -30,8 +29,7 @@ import org.apache.lucene.util.BytesRef;
|
|||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -144,7 +142,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||
}
|
||||
|
||||
|
@ -217,7 +215,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
|
||||
}
|
||||
|
||||
|
@ -294,7 +292,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
|
|||
import org.apache.lucene.index.BufferingKnnVectorsWriter;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
|
||||
private void writeGraph(
|
||||
IndexOutput graphData,
|
||||
RandomAccessVectorValuesProducer vectorValues,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
long graphDataOffset,
|
||||
long[] offsets,
|
||||
|
@ -239,7 +239,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
beamWidth,
|
||||
Lucene90HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
|
||||
|
||||
for (int ord = 0; ord < offsets.length; ord++) {
|
||||
// write graph
|
||||
|
|
|
@ -24,7 +24,6 @@ import java.util.Locale;
|
|||
import java.util.Objects;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.FixedBitSet;
|
||||
|
@ -79,14 +78,14 @@ public final class Lucene91HnswGraphBuilder {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
public Lucene91HnswGraphBuilder(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int maxConn,
|
||||
int beamWidth,
|
||||
long seed)
|
||||
throws IOException {
|
||||
vectorValues = vectors.randomAccess();
|
||||
buildVectors = vectors.randomAccess();
|
||||
vectorValues = vectors.copy();
|
||||
buildVectors = vectors.copy();
|
||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||
if (maxConn <= 0) {
|
||||
throw new IllegalArgumentException("maxConn must be positive");
|
||||
|
|
|
@ -27,7 +27,7 @@ import org.apache.lucene.index.BufferingKnnVectorsWriter;
|
|||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private Lucene91OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
||||
// build graph
|
||||
|
@ -245,7 +245,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
beamWidth,
|
||||
Lucene91HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
|
|
|
@ -29,7 +29,7 @@ import org.apache.lucene.index.BufferingKnnVectorsWriter;
|
|||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
@ -268,7 +268,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
}
|
||||
|
||||
private OnHeapHnswGraph writeGraph(
|
||||
RandomAccessVectorValuesProducer vectorValues,
|
||||
RandomAccessVectorValues vectorValues,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
|
@ -283,7 +283,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
|
|||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
|
||||
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
|
||||
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
|
|
|
@ -29,7 +29,6 @@ import org.apache.lucene.index.CorruptIndexException;
|
|||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -268,7 +267,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
private static class SimpleTextVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
implements RandomAccessVectorValues {
|
||||
|
||||
private final BytesRefBuilder scratch = new BytesRefBuilder();
|
||||
private final FieldEntry entry;
|
||||
|
@ -310,7 +309,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
public RandomAccessVectorValues copy() {
|
||||
return this;
|
||||
}
|
||||
|
||||
|
|
|
@ -422,7 +422,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
beamWidth,
|
||||
HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(offHeapVectors.randomAccess());
|
||||
graph = hnswGraphBuilder.build(offHeapVectors.copy());
|
||||
writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
@ -617,7 +617,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
hnswGraphBuilder =
|
||||
(HnswGraphBuilder<T>)
|
||||
HnswGraphBuilder.create(
|
||||
() -> raVectorValues,
|
||||
raVectorValues,
|
||||
fieldInfo.getVectorEncoding(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
M,
|
||||
|
@ -694,5 +694,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
|
|||
public BytesRef binaryValue(int targetOrd) throws IOException {
|
||||
return (BytesRef) vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,6 @@ import java.io.IOException;
|
|||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
|
@ -30,8 +29,7 @@ import org.apache.lucene.util.BytesRef;
|
|||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/** Read the vector values from the index input. This supports both iterated and random access. */
|
||||
abstract class OffHeapVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
|
@ -150,7 +148,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -226,7 +224,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
|
||||
}
|
||||
|
||||
|
@ -303,7 +301,7 @@ abstract class OffHeapVectorValues extends VectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
public RandomAccessVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
|
|
|
@ -208,7 +208,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
private static class BufferedVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
implements RandomAccessVectorValues {
|
||||
|
||||
final DocsWithFieldSet docsWithField;
|
||||
|
||||
|
@ -236,7 +236,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
|
|||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
public RandomAccessVectorValues copy() {
|
||||
return new BufferedVectorValues(docsWithField, vectors, dimension);
|
||||
}
|
||||
|
||||
|
|
|
@ -49,4 +49,11 @@ public interface RandomAccessVectorValues {
|
|||
* @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}.
|
||||
*/
|
||||
BytesRef binaryValue(int targetOrd) throws IOException;
|
||||
|
||||
/**
|
||||
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
|
||||
* access different values at once, to avoid overwriting the underlying float vector returned by
|
||||
* {@link RandomAccessVectorValues#vectorValue}.
|
||||
*/
|
||||
RandomAccessVectorValues copy() throws IOException;
|
||||
}
|
||||
|
|
|
@ -1,35 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import java.io.IOException;
|
||||
|
||||
/**
|
||||
* Something (generally a {@link VectorValues}) that provides a {@link RandomAccessVectorValues}.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public interface RandomAccessVectorValuesProducer {
|
||||
/**
|
||||
* Return a random access interface over this iterator's vectors. Calling the RandomAccess methods
|
||||
* will have no effect on the progress of the iteration or the values returned by this iterator.
|
||||
* Successive calls will retrieve independent copies that do not overwrite each others' returned
|
||||
* values.
|
||||
*/
|
||||
RandomAccessVectorValues randomAccess() throws IOException;
|
||||
}
|
|
@ -113,20 +113,14 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
};
|
||||
|
||||
/** Sorting VectorValues that iterate over documents in the order of the provided sortMap */
|
||||
public static class SortingVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValuesProducer {
|
||||
|
||||
private final VectorValues delegate;
|
||||
public static class SortingVectorValues extends VectorValues {
|
||||
private final RandomAccessVectorValues randomAccess;
|
||||
private final int[] docIdOffsets;
|
||||
private final int[] ordMap;
|
||||
private int docId = -1;
|
||||
|
||||
/** Sorting VectorValues */
|
||||
public SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.delegate = delegate;
|
||||
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
|
||||
docIdOffsets = new int[sortMap.size()];
|
||||
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
|
||||
this.randomAccess = ((RandomAccessVectorValues) delegate).copy();
|
||||
this.docIdOffsets = new int[sortMap.size()];
|
||||
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
int docID;
|
||||
|
@ -134,16 +128,6 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
|
||||
// set up ordMap to map from new dense ordinal to old dense ordinal
|
||||
ordMap = new int[offset - 1];
|
||||
int ord = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord++] = docIdOffset - 1;
|
||||
}
|
||||
}
|
||||
assert ord == ordMap.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -175,12 +159,12 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegate.dimension();
|
||||
return randomAccess.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegate.size();
|
||||
return randomAccess.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -192,36 +176,5 @@ public abstract class VectorValues extends DocIdSetIterator {
|
|||
public long cost() {
|
||||
return size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() throws IOException {
|
||||
|
||||
// Must make a new delegate randomAccess so that we have our own distinct float[]
|
||||
final RandomAccessVectorValues delegateRA =
|
||||
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
|
||||
|
||||
return new RandomAccessVectorValues() {
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegateRA.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return delegateRA.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) throws IOException {
|
||||
return delegateRA.vectorValue(ordMap[targetOrd]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,6 @@ import java.util.Locale;
|
|||
import java.util.Objects;
|
||||
import java.util.SplittableRandom;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
@ -54,7 +53,7 @@ public final class HnswGraphBuilder<T> {
|
|||
|
||||
private final VectorSimilarityFunction similarityFunction;
|
||||
private final VectorEncoding vectorEncoding;
|
||||
private final RandomAccessVectorValues vectorValues;
|
||||
private final RandomAccessVectorValues vectors;
|
||||
private final SplittableRandom random;
|
||||
private final HnswGraphSearcher<T> graphSearcher;
|
||||
|
||||
|
@ -64,10 +63,10 @@ public final class HnswGraphBuilder<T> {
|
|||
|
||||
// we need two sources of vectors in order to perform diversity check comparisons without
|
||||
// colliding
|
||||
private final RandomAccessVectorValues buildVectors;
|
||||
private final RandomAccessVectorValues vectorsCopy;
|
||||
|
||||
public static HnswGraphBuilder<?> create(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int M,
|
||||
|
@ -90,15 +89,15 @@ public final class HnswGraphBuilder<T> {
|
|||
* to ensure repeatable construction.
|
||||
*/
|
||||
private HnswGraphBuilder(
|
||||
RandomAccessVectorValuesProducer vectors,
|
||||
RandomAccessVectorValues vectors,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
int M,
|
||||
int beamWidth,
|
||||
long seed)
|
||||
throws IOException {
|
||||
vectorValues = vectors.randomAccess();
|
||||
buildVectors = vectors.randomAccess();
|
||||
this.vectors = vectors;
|
||||
this.vectorsCopy = vectors.copy();
|
||||
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
|
||||
this.similarityFunction = Objects.requireNonNull(similarityFunction);
|
||||
if (M <= 0) {
|
||||
|
@ -119,7 +118,7 @@ public final class HnswGraphBuilder<T> {
|
|||
vectorEncoding,
|
||||
similarityFunction,
|
||||
new NeighborQueue(beamWidth, true),
|
||||
new FixedBitSet(vectorValues.size()));
|
||||
new FixedBitSet(this.vectors.size()));
|
||||
// in scratch we store candidates in reverse order: worse candidates are first
|
||||
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
|
||||
}
|
||||
|
@ -129,26 +128,26 @@ public final class HnswGraphBuilder<T> {
|
|||
* enables efficient retrieval without extra data copying, while avoiding collision of the
|
||||
* returned values.
|
||||
*
|
||||
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
|
||||
* accessor for the vectors
|
||||
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
|
||||
* independent accessor for the vectors
|
||||
*/
|
||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
|
||||
if (vectors == vectorValues) {
|
||||
public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
||||
if (vectorsToAdd == this.vectors) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
|
||||
}
|
||||
if (infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
|
||||
infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
|
||||
}
|
||||
addVectors(vectors);
|
||||
addVectors(vectorsToAdd);
|
||||
return hnsw;
|
||||
}
|
||||
|
||||
private void addVectors(RandomAccessVectorValues vectors) throws IOException {
|
||||
private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException {
|
||||
long start = System.nanoTime(), t = start;
|
||||
// start at node 1! node 0 is added implicitly, in the constructor
|
||||
for (int node = 1; node < vectors.size(); node++) {
|
||||
addGraphNode(node, vectors);
|
||||
for (int node = 1; node < vectorsToAdd.size(); node++) {
|
||||
addGraphNode(node, vectorsToAdd);
|
||||
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
|
||||
t = printGraphBuildStatus(node, start, t);
|
||||
}
|
||||
|
@ -178,12 +177,12 @@ public final class HnswGraphBuilder<T> {
|
|||
|
||||
// for levels > nodeLevel search with topk = 1
|
||||
for (int level = curMaxLevel; level > nodeLevel; level--) {
|
||||
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
|
||||
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectors, hnsw);
|
||||
eps = new int[] {candidates.pop()};
|
||||
}
|
||||
// for levels <= nodeLevel search with topk = beamWidth, and add connections
|
||||
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
|
||||
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
|
||||
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectors, hnsw);
|
||||
eps = candidates.nodes();
|
||||
hnsw.addNode(level, node);
|
||||
addDiverseNeighbors(level, node, candidates);
|
||||
|
@ -282,8 +281,8 @@ public final class HnswGraphBuilder<T> {
|
|||
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
|
||||
throws IOException {
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
|
||||
case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
|
||||
case BYTE -> isDiverse(vectors.binaryValue(candidate), neighbors, score);
|
||||
case FLOAT32 -> isDiverse(vectors.vectorValue(candidate), neighbors, score);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -291,7 +290,7 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
|
||||
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
if (neighborSimilarity >= score) {
|
||||
return false;
|
||||
}
|
||||
|
@ -303,7 +302,7 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = 0; i < neighbors.size(); i++) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
|
||||
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
if (neighborSimilarity >= score) {
|
||||
return false;
|
||||
}
|
||||
|
@ -328,9 +327,9 @@ public final class HnswGraphBuilder<T> {
|
|||
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
|
||||
return switch (vectorEncoding) {
|
||||
case BYTE -> isWorstNonDiverse(
|
||||
candidate, vectorValues.binaryValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
case FLOAT32 -> isWorstNonDiverse(
|
||||
candidate, vectorValues.vectorValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -339,7 +338,7 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
|
||||
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
|
||||
// node i is too similar to node j given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return false;
|
||||
|
@ -353,7 +352,7 @@ public final class HnswGraphBuilder<T> {
|
|||
throws IOException {
|
||||
for (int i = candidateIndex - 1; i > -0; i--) {
|
||||
float neighborSimilarity =
|
||||
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
|
||||
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
|
||||
// node i is too similar to node j given its score relative to the base node
|
||||
if (neighborSimilarity >= minAcceptedSimilarity) {
|
||||
return false;
|
||||
|
|
|
@ -19,7 +19,6 @@ package org.apache.lucene.util.hnsw;
|
|||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.io.OutputStream;
|
||||
import java.lang.management.ManagementFactory;
|
||||
|
@ -54,8 +53,6 @@ import org.apache.lucene.index.IndexWriter;
|
|||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.ConstantScoreScorer;
|
||||
|
@ -278,9 +275,6 @@ public class KnnGraphTester {
|
|||
testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
|
||||
}
|
||||
break;
|
||||
case "-dump":
|
||||
dumpGraph(docVectorsPath);
|
||||
break;
|
||||
case "-stats":
|
||||
printFanoutHist(indexPath);
|
||||
break;
|
||||
|
@ -308,41 +302,6 @@ public class KnnGraphTester {
|
|||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private void dumpGraph(Path docsPath) throws IOException {
|
||||
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
|
||||
RandomAccessVectorValues values = vectors.randomAccess();
|
||||
HnswGraphBuilder<float[]> builder =
|
||||
(HnswGraphBuilder<float[]>)
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
|
||||
// start at node 1
|
||||
for (int i = 1; i < numDocs; i++) {
|
||||
builder.addGraphNode(i, values);
|
||||
System.out.println("\nITERATION " + i);
|
||||
dumpGraph(builder.hnsw);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void dumpGraph(OnHeapHnswGraph hnsw) {
|
||||
for (int i = 0; i < hnsw.size(); i++) {
|
||||
NeighborArray neighbors = hnsw.getNeighbors(0, i);
|
||||
System.out.printf(Locale.ROOT, "%5d", i);
|
||||
NeighborArray sorted = new NeighborArray(neighbors.size(), true);
|
||||
for (int j = 0; j < neighbors.size(); j++) {
|
||||
int node = neighbors.node[j];
|
||||
float score = neighbors.score[j];
|
||||
sorted.add(node, score);
|
||||
}
|
||||
new NeighborArraySorter(sorted).sort(0, sorted.size());
|
||||
for (int j = 0; j < sorted.size(); j++) {
|
||||
System.out.printf(Locale.ROOT, " [%d, %.4f]", sorted.node[j], sorted.score[j]);
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressForbidden(reason = "Prints stuff")
|
||||
private void forceMerge() throws IOException {
|
||||
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
|
||||
|
@ -783,66 +742,6 @@ public class KnnGraphTester {
|
|||
System.exit(1);
|
||||
}
|
||||
|
||||
class BinaryFileVectors implements RandomAccessVectorValuesProducer, Closeable {
|
||||
|
||||
private final int size;
|
||||
private final FileChannel in;
|
||||
private final FloatBuffer mmap;
|
||||
|
||||
BinaryFileVectors(Path filePath) throws IOException {
|
||||
in = FileChannel.open(filePath);
|
||||
long totalBytes = (long) numDocs * dim * Float.BYTES;
|
||||
if (totalBytes > Integer.MAX_VALUE) {
|
||||
throw new IllegalArgumentException("input over 2GB not supported");
|
||||
}
|
||||
int vectorByteSize = dim * Float.BYTES;
|
||||
size = (int) (totalBytes / vectorByteSize);
|
||||
mmap =
|
||||
in.map(FileChannel.MapMode.READ_ONLY, 0, totalBytes)
|
||||
.order(ByteOrder.LITTLE_ENDIAN)
|
||||
.asFloatBuffer();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
in.close();
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
return new Values();
|
||||
}
|
||||
|
||||
class Values implements RandomAccessVectorValues {
|
||||
|
||||
float[] vector = new float[dim];
|
||||
FloatBuffer source = mmap.slice();
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
int pos = targetOrd * dim;
|
||||
source.position(pos);
|
||||
source.get(vector);
|
||||
return vector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public BytesRef binaryValue(int targetOrd) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static class NeighborArraySorter extends IntroSorter {
|
||||
private final int[] node;
|
||||
private final float[] score;
|
||||
|
|
|
@ -18,13 +18,11 @@
|
|||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.BytesRef;
|
||||
|
||||
class MockVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
private final float[] scratch;
|
||||
|
||||
protected final int dimension;
|
||||
|
@ -53,6 +51,7 @@ class MockVectorValues extends VectorValues
|
|||
binaryValue.length = dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MockVectorValues copy() {
|
||||
return new MockVectorValues(values);
|
||||
}
|
||||
|
@ -81,11 +80,6 @@ class MockVectorValues extends VectorValues
|
|||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
return copy();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue(int targetOrd) {
|
||||
return denseValues[targetOrd];
|
||||
|
|
|
@ -43,7 +43,6 @@ import org.apache.lucene.index.IndexWriter;
|
|||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
|
@ -94,7 +93,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
|
||||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
|
||||
HnswGraph hnsw = builder.build(vectors);
|
||||
HnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
// Recreate the graph while indexing with the same random seed and write it out
|
||||
HnswGraphBuilder.randSeed = seed;
|
||||
|
@ -275,13 +274,13 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// run some searches
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
10,
|
||||
vectors.randomAccess(),
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -316,14 +315,14 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// the first 10 docs must not be deleted to ensure the expected recall
|
||||
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
|
||||
NeighborQueue nn =
|
||||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
10,
|
||||
vectors.randomAccess(),
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -349,7 +348,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
// Only mark a few vectors as accepted
|
||||
BitSet acceptOrds = new FixedBitSet(vectors.size);
|
||||
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
|
||||
|
@ -362,7 +361,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
numAccepted,
|
||||
vectors.randomAccess(),
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -386,7 +385,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
// Skip over half of the documents that are closest to the query vector
|
||||
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
|
||||
|
@ -397,7 +396,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
10,
|
||||
vectors.randomAccess(),
|
||||
vectors.copy(),
|
||||
VectorEncoding.FLOAT32,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -423,7 +422,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
|
||||
int topK = 50;
|
||||
int visitedLimit = topK + random().nextInt(5);
|
||||
|
@ -431,7 +430,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphSearcher.search(
|
||||
getTargetVector(),
|
||||
topK,
|
||||
vectors.randomAccess(),
|
||||
vectors.copy(),
|
||||
vectorEncoding,
|
||||
similarityFunction,
|
||||
hnsw,
|
||||
|
@ -496,15 +495,16 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
|
||||
// node 0 is added by the builder constructor
|
||||
// builder.addGraphNode(vectors.vectorValue(0));
|
||||
builder.addGraphNode(1, vectors);
|
||||
builder.addGraphNode(2, vectors);
|
||||
RandomAccessVectorValues vectorsCopy = vectors.copy();
|
||||
builder.addGraphNode(1, vectorsCopy);
|
||||
builder.addGraphNode(2, vectorsCopy);
|
||||
// now every node has tried to attach every other node as a neighbor, but
|
||||
// some were excluded based on diversity check.
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
|
||||
builder.addGraphNode(3, vectors);
|
||||
builder.addGraphNode(3, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
// we added 3 here
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
|
||||
|
@ -512,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertLevel0Neighbors(builder.hnsw, 3, 1);
|
||||
|
||||
// supplant an existing neighbor
|
||||
builder.addGraphNode(4, vectors);
|
||||
builder.addGraphNode(4, vectorsCopy);
|
||||
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
|
||||
|
@ -521,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
|
||||
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
|
||||
|
||||
builder.addGraphNode(5, vectors);
|
||||
builder.addGraphNode(5, vectorsCopy);
|
||||
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
|
||||
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
|
||||
assertLevel0Neighbors(builder.hnsw, 2, 0);
|
||||
|
@ -550,7 +550,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
HnswGraphBuilder<?> builder =
|
||||
HnswGraphBuilder.create(
|
||||
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors);
|
||||
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
|
||||
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
|
||||
|
||||
int totalMatches = 0;
|
||||
|
@ -617,8 +617,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
}
|
||||
|
||||
/** Returns vectors evenly distributed around the upper unit semicircle. */
|
||||
static class CircularVectorValues extends VectorValues
|
||||
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
|
||||
static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues {
|
||||
private final int size;
|
||||
private final float[] value;
|
||||
private final BytesRef binaryValue;
|
||||
|
@ -631,6 +630,7 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
binaryValue = new BytesRef(new byte[2]);
|
||||
}
|
||||
|
||||
@Override
|
||||
public CircularVectorValues copy() {
|
||||
return new CircularVectorValues(size);
|
||||
}
|
||||
|
@ -650,11 +650,6 @@ public class TestHnswGraph extends LuceneTestCase {
|
|||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues randomAccess() {
|
||||
return new CircularVectorValues(size);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
|
|
Loading…
Reference in New Issue