LUCENE-9583: Remove RandomAccessVectorValuesProducer (#1071)

This change folds the `RandomAccessVectorValuesProducer` interface into
`RandomAccessVectorValues`. This reduces the number of interfaces and clarifies
the cloning/ copying behavior.

This is a small simplification related to LUCENE-9583, but does not address the
main issue.
This commit is contained in:
Julie Tibshirani 2022-08-19 18:04:05 -07:00 committed by GitHub
parent 0914b537db
commit 8308688d78
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 99 additions and 293 deletions

View File

@ -22,7 +22,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;
@ -72,14 +71,14 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
long seed)
throws IOException {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
vectorValues = vectors.copy();
buildVectors = vectors.copy();
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (maxConn <= 0) {
throw new IllegalArgumentException("maxConn must be positive");

View File

@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -381,8 +380,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
final int dimension;
final int[] ordToDoc;
@ -468,7 +466,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
}
@Override
public RandomAccessVectorValues randomAccess() {
public RandomAccessVectorValues copy() {
return new OffHeapVectorValues(dimension, ordToDoc, dataIn.clone());
}

View File

@ -33,7 +33,6 @@ import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -422,8 +421,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
static class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
private final int dimension;
private final int size;
@ -516,7 +514,7 @@ public final class Lucene91HnswVectorsReader extends KnnVectorsReader {
}
@Override
public RandomAccessVectorValues randomAccess() {
public RandomAccessVectorValues copy() {
return new OffHeapVectorValues(dimension, size, ordToDoc, dataIn.clone());
}

View File

@ -21,7 +21,6 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
@ -30,8 +29,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
protected final int dimension;
protected final int size;
@ -144,7 +142,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
}
@ -217,7 +215,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone());
}
@ -294,7 +292,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -26,7 +26,7 @@ import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -224,7 +224,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
private void writeGraph(
IndexOutput graphData,
RandomAccessVectorValuesProducer vectorValues,
RandomAccessVectorValues vectorValues,
VectorSimilarityFunction similarityFunction,
long graphDataOffset,
long[] offsets,
@ -239,7 +239,7 @@ public final class Lucene90HnswVectorsWriter extends BufferingKnnVectorsWriter {
beamWidth,
Lucene90HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
Lucene90OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
for (int ord = 0; ord < offsets.length; ord++) {
// write graph

View File

@ -24,7 +24,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.FixedBitSet;
@ -79,14 +78,14 @@ public final class Lucene91HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene91HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
RandomAccessVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
long seed)
throws IOException {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
vectorValues = vectors.copy();
buildVectors = vectors.copy();
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (maxConn <= 0) {
throw new IllegalArgumentException("maxConn must be positive");

View File

@ -27,7 +27,7 @@ import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -233,7 +233,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private Lucene91OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues, VectorSimilarityFunction similarityFunction)
RandomAccessVectorValues vectorValues, VectorSimilarityFunction similarityFunction)
throws IOException {
// build graph
@ -245,7 +245,7 @@ public final class Lucene91HnswVectorsWriter extends BufferingKnnVectorsWriter {
beamWidth,
Lucene91HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
Lucene91OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();

View File

@ -29,7 +29,7 @@ import org.apache.lucene.index.BufferingKnnVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
@ -268,7 +268,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
}
private OnHeapHnswGraph writeGraph(
RandomAccessVectorValuesProducer vectorValues,
RandomAccessVectorValues vectorValues,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
@ -283,7 +283,7 @@ public final class Lucene92HnswVectorsWriter extends BufferingKnnVectorsWriter {
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.randomAccess());
OnHeapHnswGraph graph = hnswGraphBuilder.build(vectorValues.copy());
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();

View File

@ -29,7 +29,6 @@ import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -268,7 +267,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
private static class SimpleTextVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
implements RandomAccessVectorValues {
private final BytesRefBuilder scratch = new BytesRefBuilder();
private final FieldEntry entry;
@ -310,7 +309,7 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
}
@Override
public RandomAccessVectorValues randomAccess() {
public RandomAccessVectorValues copy() {
return this;
}

View File

@ -422,7 +422,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
beamWidth,
HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(offHeapVectors.randomAccess());
graph = hnswGraphBuilder.build(offHeapVectors.copy());
writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
@ -617,7 +617,7 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
hnswGraphBuilder =
(HnswGraphBuilder<T>)
HnswGraphBuilder.create(
() -> raVectorValues,
raVectorValues,
fieldInfo.getVectorEncoding(),
fieldInfo.getVectorSimilarityFunction(),
M,
@ -694,5 +694,10 @@ public final class Lucene94HnswVectorsWriter extends KnnVectorsWriter {
public BytesRef binaryValue(int targetOrd) throws IOException {
return (BytesRef) vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues copy() throws IOException {
return this;
}
}
}

View File

@ -21,7 +21,6 @@ import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
@ -30,8 +29,7 @@ import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/** Read the vector values from the index input. This supports both iterated and random access. */
abstract class OffHeapVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
abstract class OffHeapVectorValues extends VectorValues implements RandomAccessVectorValues {
protected final int dimension;
protected final int size;
@ -150,7 +148,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone(), byteSize);
}
@ -226,7 +224,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(fieldEntry, dataIn, slice.clone(), byteSize);
}
@ -303,7 +301,7 @@ abstract class OffHeapVectorValues extends VectorValues
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
public RandomAccessVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}

View File

@ -208,7 +208,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
private static class BufferedVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
implements RandomAccessVectorValues {
final DocsWithFieldSet docsWithField;
@ -236,7 +236,7 @@ public abstract class BufferingKnnVectorsWriter extends KnnVectorsWriter {
}
@Override
public RandomAccessVectorValues randomAccess() {
public RandomAccessVectorValues copy() {
return new BufferedVectorValues(docsWithField, vectors, dimension);
}

View File

@ -49,4 +49,11 @@ public interface RandomAccessVectorValues {
* @param targetOrd a valid ordinal, &ge; 0 and &lt; {@link #size()}.
*/
BytesRef binaryValue(int targetOrd) throws IOException;
/**
* Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to
* access different values at once, to avoid overwriting the underlying float vector returned by
* {@link RandomAccessVectorValues#vectorValue}.
*/
RandomAccessVectorValues copy() throws IOException;
}

View File

@ -1,35 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.index;
import java.io.IOException;
/**
* Something (generally a {@link VectorValues}) that provides a {@link RandomAccessVectorValues}.
*
* @lucene.experimental
*/
public interface RandomAccessVectorValuesProducer {
/**
* Return a random access interface over this iterator's vectors. Calling the RandomAccess methods
* will have no effect on the progress of the iteration or the values returned by this iterator.
* Successive calls will retrieve independent copies that do not overwrite each others' returned
* values.
*/
RandomAccessVectorValues randomAccess() throws IOException;
}

View File

@ -113,20 +113,14 @@ public abstract class VectorValues extends DocIdSetIterator {
};
/** Sorting VectorValues that iterate over documents in the order of the provided sortMap */
public static class SortingVectorValues extends VectorValues
implements RandomAccessVectorValuesProducer {
private final VectorValues delegate;
public static class SortingVectorValues extends VectorValues {
private final RandomAccessVectorValues randomAccess;
private final int[] docIdOffsets;
private final int[] ordMap;
private int docId = -1;
/** Sorting VectorValues */
public SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.delegate = delegate;
randomAccess = ((RandomAccessVectorValuesProducer) delegate).randomAccess();
docIdOffsets = new int[sortMap.size()];
SortingVectorValues(VectorValues delegate, Sorter.DocMap sortMap) throws IOException {
this.randomAccess = ((RandomAccessVectorValues) delegate).copy();
this.docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
int docID;
@ -134,16 +128,6 @@ public abstract class VectorValues extends DocIdSetIterator {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
// set up ordMap to map from new dense ordinal to old dense ordinal
ordMap = new int[offset - 1];
int ord = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord++] = docIdOffset - 1;
}
}
assert ord == ordMap.length;
}
@Override
@ -175,12 +159,12 @@ public abstract class VectorValues extends DocIdSetIterator {
@Override
public int dimension() {
return delegate.dimension();
return randomAccess.dimension();
}
@Override
public int size() {
return delegate.size();
return randomAccess.size();
}
@Override
@ -192,36 +176,5 @@ public abstract class VectorValues extends DocIdSetIterator {
public long cost() {
return size();
}
@Override
public RandomAccessVectorValues randomAccess() throws IOException {
// Must make a new delegate randomAccess so that we have our own distinct float[]
final RandomAccessVectorValues delegateRA =
((RandomAccessVectorValuesProducer) SortingVectorValues.this.delegate).randomAccess();
return new RandomAccessVectorValues() {
@Override
public int size() {
return delegateRA.size();
}
@Override
public int dimension() {
return delegateRA.dimension();
}
@Override
public float[] vectorValue(int targetOrd) throws IOException {
return delegateRA.vectorValue(ordMap[targetOrd]);
}
@Override
public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException();
}
};
}
}
}

View File

@ -24,7 +24,6 @@ import java.util.Locale;
import java.util.Objects;
import java.util.SplittableRandom;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.BytesRef;
@ -54,7 +53,7 @@ public final class HnswGraphBuilder<T> {
private final VectorSimilarityFunction similarityFunction;
private final VectorEncoding vectorEncoding;
private final RandomAccessVectorValues vectorValues;
private final RandomAccessVectorValues vectors;
private final SplittableRandom random;
private final HnswGraphSearcher<T> graphSearcher;
@ -64,10 +63,10 @@ public final class HnswGraphBuilder<T> {
// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues buildVectors;
private final RandomAccessVectorValues vectorsCopy;
public static HnswGraphBuilder<?> create(
RandomAccessVectorValuesProducer vectors,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
@ -90,15 +89,15 @@ public final class HnswGraphBuilder<T> {
* to ensure repeatable construction.
*/
private HnswGraphBuilder(
RandomAccessVectorValuesProducer vectors,
RandomAccessVectorValues vectors,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction,
int M,
int beamWidth,
long seed)
throws IOException {
vectorValues = vectors.randomAccess();
buildVectors = vectors.randomAccess();
this.vectors = vectors;
this.vectorsCopy = vectors.copy();
this.vectorEncoding = Objects.requireNonNull(vectorEncoding);
this.similarityFunction = Objects.requireNonNull(similarityFunction);
if (M <= 0) {
@ -119,7 +118,7 @@ public final class HnswGraphBuilder<T> {
vectorEncoding,
similarityFunction,
new NeighborQueue(beamWidth, true),
new FixedBitSet(vectorValues.size()));
new FixedBitSet(this.vectors.size()));
// in scratch we store candidates in reverse order: worse candidates are first
scratch = new NeighborArray(Math.max(beamWidth, M + 1), false);
}
@ -129,26 +128,26 @@ public final class HnswGraphBuilder<T> {
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors
* @param vectorsToAdd the vectors for which to build a nearest neighbors graph. Must be an
* independent accessor for the vectors
*/
public OnHeapHnswGraph build(RandomAccessVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
public OnHeapHnswGraph build(RandomAccessVectorValues vectorsToAdd) throws IOException {
if (vectorsToAdd == this.vectors) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
}
if (infoStream.isEnabled(HNSW_COMPONENT)) {
infoStream.message(HNSW_COMPONENT, "build graph from " + vectors.size() + " vectors");
infoStream.message(HNSW_COMPONENT, "build graph from " + vectorsToAdd.size() + " vectors");
}
addVectors(vectors);
addVectors(vectorsToAdd);
return hnsw;
}
private void addVectors(RandomAccessVectorValues vectors) throws IOException {
private void addVectors(RandomAccessVectorValues vectorsToAdd) throws IOException {
long start = System.nanoTime(), t = start;
// start at node 1! node 0 is added implicitly, in the constructor
for (int node = 1; node < vectors.size(); node++) {
addGraphNode(node, vectors);
for (int node = 1; node < vectorsToAdd.size(); node++) {
addGraphNode(node, vectorsToAdd);
if ((node % 10000 == 0) && infoStream.isEnabled(HNSW_COMPONENT)) {
t = printGraphBuildStatus(node, start, t);
}
@ -178,12 +177,12 @@ public final class HnswGraphBuilder<T> {
// for levels > nodeLevel search with topk = 1
for (int level = curMaxLevel; level > nodeLevel; level--) {
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectorValues, hnsw);
candidates = graphSearcher.searchLevel(value, 1, level, eps, vectors, hnsw);
eps = new int[] {candidates.pop()};
}
// for levels <= nodeLevel search with topk = beamWidth, and add connections
for (int level = Math.min(nodeLevel, curMaxLevel); level >= 0; level--) {
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectorValues, hnsw);
candidates = graphSearcher.searchLevel(value, beamWidth, level, eps, vectors, hnsw);
eps = candidates.nodes();
hnsw.addNode(level, node);
addDiverseNeighbors(level, node, candidates);
@ -282,8 +281,8 @@ public final class HnswGraphBuilder<T> {
private boolean isDiverse(int candidate, NeighborArray neighbors, float score)
throws IOException {
return switch (vectorEncoding) {
case BYTE -> isDiverse(vectorValues.binaryValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse(vectorValues.vectorValue(candidate), neighbors, score);
case BYTE -> isDiverse(vectors.binaryValue(candidate), neighbors, score);
case FLOAT32 -> isDiverse(vectors.vectorValue(candidate), neighbors, score);
};
}
@ -291,7 +290,7 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@ -303,7 +302,7 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = 0; i < neighbors.size(); i++) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
if (neighborSimilarity >= score) {
return false;
}
@ -328,9 +327,9 @@ public final class HnswGraphBuilder<T> {
int candidate, NeighborArray neighbors, float minAcceptedSimilarity) throws IOException {
return switch (vectorEncoding) {
case BYTE -> isWorstNonDiverse(
candidate, vectorValues.binaryValue(candidate), neighbors, minAcceptedSimilarity);
candidate, vectors.binaryValue(candidate), neighbors, minAcceptedSimilarity);
case FLOAT32 -> isWorstNonDiverse(
candidate, vectorValues.vectorValue(candidate), neighbors, minAcceptedSimilarity);
candidate, vectors.vectorValue(candidate), neighbors, minAcceptedSimilarity);
};
}
@ -339,7 +338,7 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.vectorValue(neighbors.node[i]));
similarityFunction.compare(candidate, vectorsCopy.vectorValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;
@ -353,7 +352,7 @@ public final class HnswGraphBuilder<T> {
throws IOException {
for (int i = candidateIndex - 1; i > -0; i--) {
float neighborSimilarity =
similarityFunction.compare(candidate, buildVectors.binaryValue(neighbors.node[i]));
similarityFunction.compare(candidate, vectorsCopy.binaryValue(neighbors.node[i]));
// node i is too similar to node j given its score relative to the base node
if (neighborSimilarity >= minAcceptedSimilarity) {
return false;

View File

@ -19,7 +19,6 @@ package org.apache.lucene.util.hnsw;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
@ -54,8 +53,6 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.ConstantScoreScorer;
@ -278,9 +275,6 @@ public class KnnGraphTester {
testSearch(indexPath, queryPath, null, getNN(docVectorsPath, queryPath));
}
break;
case "-dump":
dumpGraph(docVectorsPath);
break;
case "-stats":
printFanoutHist(indexPath);
break;
@ -308,41 +302,6 @@ public class KnnGraphTester {
}
}
@SuppressWarnings("unchecked")
private void dumpGraph(Path docsPath) throws IOException {
try (BinaryFileVectors vectors = new BinaryFileVectors(docsPath)) {
RandomAccessVectorValues values = vectors.randomAccess();
HnswGraphBuilder<float[]> builder =
(HnswGraphBuilder<float[]>)
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, maxConn, beamWidth, 0);
// start at node 1
for (int i = 1; i < numDocs; i++) {
builder.addGraphNode(i, values);
System.out.println("\nITERATION " + i);
dumpGraph(builder.hnsw);
}
}
}
private void dumpGraph(OnHeapHnswGraph hnsw) {
for (int i = 0; i < hnsw.size(); i++) {
NeighborArray neighbors = hnsw.getNeighbors(0, i);
System.out.printf(Locale.ROOT, "%5d", i);
NeighborArray sorted = new NeighborArray(neighbors.size(), true);
for (int j = 0; j < neighbors.size(); j++) {
int node = neighbors.node[j];
float score = neighbors.score[j];
sorted.add(node, score);
}
new NeighborArraySorter(sorted).sort(0, sorted.size());
for (int j = 0; j < sorted.size(); j++) {
System.out.printf(Locale.ROOT, " [%d, %.4f]", sorted.node[j], sorted.score[j]);
}
System.out.println();
}
}
@SuppressForbidden(reason = "Prints stuff")
private void forceMerge() throws IOException {
IndexWriterConfig iwc = new IndexWriterConfig().setOpenMode(IndexWriterConfig.OpenMode.APPEND);
@ -783,66 +742,6 @@ public class KnnGraphTester {
System.exit(1);
}
class BinaryFileVectors implements RandomAccessVectorValuesProducer, Closeable {
private final int size;
private final FileChannel in;
private final FloatBuffer mmap;
BinaryFileVectors(Path filePath) throws IOException {
in = FileChannel.open(filePath);
long totalBytes = (long) numDocs * dim * Float.BYTES;
if (totalBytes > Integer.MAX_VALUE) {
throw new IllegalArgumentException("input over 2GB not supported");
}
int vectorByteSize = dim * Float.BYTES;
size = (int) (totalBytes / vectorByteSize);
mmap =
in.map(FileChannel.MapMode.READ_ONLY, 0, totalBytes)
.order(ByteOrder.LITTLE_ENDIAN)
.asFloatBuffer();
}
@Override
public void close() throws IOException {
in.close();
}
@Override
public RandomAccessVectorValues randomAccess() {
return new Values();
}
class Values implements RandomAccessVectorValues {
float[] vector = new float[dim];
FloatBuffer source = mmap.slice();
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return dim;
}
@Override
public float[] vectorValue(int targetOrd) {
int pos = targetOrd * dim;
source.position(pos);
source.get(vector);
return vector;
}
@Override
public BytesRef binaryValue(int targetOrd) {
throw new UnsupportedOperationException();
}
}
}
static class NeighborArraySorter extends IntroSorter {
private final int[] node;
private final float[] score;

View File

@ -18,13 +18,11 @@
package org.apache.lucene.util.hnsw;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.BytesRef;
class MockVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
class MockVectorValues extends VectorValues implements RandomAccessVectorValues {
private final float[] scratch;
protected final int dimension;
@ -53,6 +51,7 @@ class MockVectorValues extends VectorValues
binaryValue.length = dimension;
}
@Override
public MockVectorValues copy() {
return new MockVectorValues(values);
}
@ -81,11 +80,6 @@ class MockVectorValues extends VectorValues
}
}
@Override
public RandomAccessVectorValues randomAccess() {
return copy();
}
@Override
public float[] vectorValue(int targetOrd) {
return denseValues[targetOrd];

View File

@ -43,7 +43,6 @@ import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
@ -94,7 +93,7 @@ public class TestHnswGraph extends LuceneTestCase {
RandomVectorValues v2 = vectors.copy(), v3 = vectors.copy();
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(vectors, vectorEncoding, similarityFunction, M, beamWidth, seed);
HnswGraph hnsw = builder.build(vectors);
HnswGraph hnsw = builder.build(vectors.copy());
// Recreate the graph while indexing with the same random seed and write it out
HnswGraphBuilder.randSeed = seed;
@ -275,13 +274,13 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// run some searches
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.randomAccess(),
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
@ -316,14 +315,14 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// the first 10 docs must not be deleted to ensure the expected recall
Bits acceptOrds = createRandomAcceptOrds(10, vectors.size);
NeighborQueue nn =
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.randomAccess(),
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
@ -349,7 +348,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// Only mark a few vectors as accepted
BitSet acceptOrds = new FixedBitSet(vectors.size);
for (int i = 0; i < vectors.size; i += random().nextInt(15, 20)) {
@ -362,7 +361,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphSearcher.search(
getTargetVector(),
numAccepted,
vectors.randomAccess(),
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
@ -386,7 +385,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, VectorEncoding.FLOAT32, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
// Skip over half of the documents that are closest to the query vector
FixedBitSet acceptOrds = new FixedBitSet(nDoc);
@ -397,7 +396,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphSearcher.search(
getTargetVector(),
10,
vectors.randomAccess(),
vectors.copy(),
VectorEncoding.FLOAT32,
similarityFunction,
hnsw,
@ -423,7 +422,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 16, 100, random().nextInt());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
int topK = 50;
int visitedLimit = topK + random().nextInt(5);
@ -431,7 +430,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphSearcher.search(
getTargetVector(),
topK,
vectors.randomAccess(),
vectors.copy(),
vectorEncoding,
similarityFunction,
hnsw,
@ -496,15 +495,16 @@ public class TestHnswGraph extends LuceneTestCase {
vectors, vectorEncoding, similarityFunction, 2, 10, random().nextInt());
// node 0 is added by the builder constructor
// builder.addGraphNode(vectors.vectorValue(0));
builder.addGraphNode(1, vectors);
builder.addGraphNode(2, vectors);
RandomAccessVectorValues vectorsCopy = vectors.copy();
builder.addGraphNode(1, vectorsCopy);
builder.addGraphNode(2, vectorsCopy);
// now every node has tried to attach every other node as a neighbor, but
// some were excluded based on diversity check.
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0);
assertLevel0Neighbors(builder.hnsw, 2, 0);
builder.addGraphNode(3, vectors);
builder.addGraphNode(3, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
// we added 3 here
assertLevel0Neighbors(builder.hnsw, 1, 0, 3);
@ -512,7 +512,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1);
// supplant an existing neighbor
builder.addGraphNode(4, vectors);
builder.addGraphNode(4, vectorsCopy);
// 4 is the same distance from 0 that 2 is; we leave the existing node in place
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4);
@ -521,7 +521,7 @@ public class TestHnswGraph extends LuceneTestCase {
assertLevel0Neighbors(builder.hnsw, 3, 1, 4);
assertLevel0Neighbors(builder.hnsw, 4, 1, 3);
builder.addGraphNode(5, vectors);
builder.addGraphNode(5, vectorsCopy);
assertLevel0Neighbors(builder.hnsw, 0, 1, 2);
assertLevel0Neighbors(builder.hnsw, 1, 0, 3, 4, 5);
assertLevel0Neighbors(builder.hnsw, 2, 0);
@ -550,7 +550,7 @@ public class TestHnswGraph extends LuceneTestCase {
HnswGraphBuilder<?> builder =
HnswGraphBuilder.create(
vectors, vectorEncoding, similarityFunction, 10, 30, random().nextLong());
OnHeapHnswGraph hnsw = builder.build(vectors);
OnHeapHnswGraph hnsw = builder.build(vectors.copy());
Bits acceptOrds = random().nextBoolean() ? null : createRandomAcceptOrds(0, size);
int totalMatches = 0;
@ -617,8 +617,7 @@ public class TestHnswGraph extends LuceneTestCase {
}
/** Returns vectors evenly distributed around the upper unit semicircle. */
static class CircularVectorValues extends VectorValues
implements RandomAccessVectorValues, RandomAccessVectorValuesProducer {
static class CircularVectorValues extends VectorValues implements RandomAccessVectorValues {
private final int size;
private final float[] value;
private final BytesRef binaryValue;
@ -631,6 +630,7 @@ public class TestHnswGraph extends LuceneTestCase {
binaryValue = new BytesRef(new byte[2]);
}
@Override
public CircularVectorValues copy() {
return new CircularVectorValues(size);
}
@ -650,11 +650,6 @@ public class TestHnswGraph extends LuceneTestCase {
return vectorValue(doc);
}
@Override
public RandomAccessVectorValues randomAccess() {
return new CircularVectorValues(size);
}
@Override
public int docID() {
return doc;