Select random or sequential access for the knn codec depending on its access pattern

This commit is contained in:
Jim Ferenczi 2024-12-17 13:34:02 +00:00
parent 6867430140
commit 9649461e35
12 changed files with 77 additions and 37 deletions

View File

@ -30,6 +30,7 @@ import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsWriter;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsWriter;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ReadAdvice;
class Lucene99RWHnswScalarQuantizationVectorsFormat
extends Lucene99HnswScalarQuantizedVectorsFormat {
@ -54,7 +55,7 @@ class Lucene99RWHnswScalarQuantizationVectorsFormat
static class Lucene99RWScalarQuantizedFormat extends Lucene99ScalarQuantizedVectorsFormat {
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer());
new Lucene99FlatVectorsFormat(new DefaultFlatVectorScorer(), ReadAdvice.RANDOM);
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
@ -62,7 +63,8 @@ class Lucene99RWHnswScalarQuantizationVectorsFormat
state,
null,
rawVectorFormat.fieldsWriter(state),
new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()));
new ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()),
ReadAdvice.RANDOM);
}
}
}

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
@ -128,7 +129,7 @@ public final class HnswBitVectorsFormat extends KnnVectorsFormat {
} else {
this.mergeExec = null;
}
this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer());
this.flatVectorsFormat = new Lucene99FlatVectorsFormat(new FlatBitVectorsScorer(), ReadAdvice.RANDOM);
}
@Override

View File

@ -27,6 +27,7 @@ import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.ReadAdvice;
/**
* Lucene 9.9 flat vector format, which encodes numeric vector values
@ -78,21 +79,23 @@ public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
private final FlatVectorsScorer vectorsScorer;
private final ReadAdvice readAdvice;
/** Constructs a format */
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer) {
public Lucene99FlatVectorsFormat(FlatVectorsScorer vectorsScorer, ReadAdvice readAdvice) {
super(NAME);
this.vectorsScorer = vectorsScorer;
this.readAdvice = readAdvice;
}
@Override
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99FlatVectorsWriter(state, vectorsScorer);
return new Lucene99FlatVectorsWriter(state, vectorsScorer, readAdvice);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99FlatVectorsReader(state, vectorsScorer);
return new Lucene99FlatVectorsReader(state, vectorsScorer, readAdvice);
}
@Override

View File

@ -59,7 +59,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
private final IndexInput vectorData;
private final FieldInfos fieldInfos;
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer)
public Lucene99FlatVectorsReader(SegmentReadState state, FlatVectorsScorer scorer, ReadAdvice readAdvice)
throws IOException {
super(scorer);
int versionMeta = readMetadata(state);
@ -72,9 +72,7 @@ public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
versionMeta,
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
// Flat formats are used to randomly access vectors from their node ID that is stored
// in the HNSW graph.
state.context.withReadAdvice(ReadAdvice.RANDOM));
state.context.withReadAdvice(readAdvice));
success = true;
} finally {
if (success == false) {

View File

@ -66,15 +66,18 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class);
private final SegmentWriteState segmentWriteState;
private final ReadAdvice readAdvice;
private final IndexOutput meta, vectorData;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer)
public Lucene99FlatVectorsWriter(SegmentWriteState state, FlatVectorsScorer scorer, ReadAdvice readAdvice)
throws IOException {
super(scorer);
segmentWriteState = state;
this.readAdvice = readAdvice;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
@ -282,7 +285,7 @@ public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
// to perform random reads.
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), IOContext.DEFAULT.withReadAdvice(ReadAdvice.RANDOM));
tempVectorData.getName(), IOContext.DEFAULT.withReadAdvice(readAdvice));
// copy the temporary file vectors to the actual data file
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);

View File

@ -32,6 +32,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
@ -134,8 +135,14 @@ public class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
} else {
this.mergeExec = null;
}
/*
* Defines the format used for storing, reading, and merging vectors on disk.
* Vectors are looked up by the node ID recorded in the HNSW graph, so the flat format
* is accessed randomly rather than sequentially.
* To match that access pattern, the ReadAdvice.RANDOM read advice is used.
*/
this.flatVectorsFormat =
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress);
new Lucene99ScalarQuantizedVectorsFormat(confidenceInterval, bits, compress, ReadAdvice.RANDOM);
}
@Override

View File

@ -29,6 +29,7 @@ import org.apache.lucene.index.MergeScheduler;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.TaskExecutor;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
@ -130,9 +131,13 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
*/
private final int beamWidth;
/** The format for storing, reading, and merging vectors on disk. */
/**
* Defines the format used for storing, reading, and merging vectors on disk.
* Vectors are looked up by the node ID recorded in the HNSW graph, so the flat format
* is accessed randomly rather than sequentially.
* To match that access pattern, the {@link ReadAdvice#RANDOM} read advice is used.
*/
private static final FlatVectorsFormat flatVectorsFormat =
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), ReadAdvice.RANDOM);
private final int numMergeWorkers;
private final TaskExecutor mergeExec;

View File

@ -25,6 +25,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.ReadAdvice;
/**
* Format supporting vector quantization, storage, and retrieval
@ -50,8 +51,13 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
static final String META_EXTENSION = "vemq";
static final String VECTOR_DATA_EXTENSION = "veq";
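/**
* Defines the format used for storing, reading, and merging raw vectors on disk.
* The raw vectors always use the {@link ReadAdvice#SEQUENTIAL} read advice,
* as nearest neighbors are retrieved exclusively using a brute-force approach.
* This advice is fixed: the read advice passed to the constructor of this format is
* forwarded only to the quantized vectors reader and writer, not to this raw format.
*/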
private static final FlatVectorsFormat rawVectorFormat =
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer());
new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer(), ReadAdvice.SEQUENTIAL);
/** The minimum confidence interval */
private static final float MINIMUM_CONFIDENCE_INTERVAL = 0.9f;
@ -71,10 +77,15 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
final ReadAdvice readAdvice;
/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
this(null, 7, false);
// By default, the SEQUENTIAL read advice is used for this format,
// as nearest neighbors are retrieved exclusively using a brute-force approach.
this(null, 7, false, ReadAdvice.SEQUENTIAL);
}
/**
@ -91,7 +102,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
* during searching, at some decode speed penalty.
*/
public Lucene99ScalarQuantizedVectorsFormat(
Float confidenceInterval, int bits, boolean compress) {
Float confidenceInterval, int bits, boolean compress, ReadAdvice readAdvice) {
super(NAME);
if (confidenceInterval != null
&& confidenceInterval != DYNAMIC_CONFIDENCE_INTERVAL
@ -119,6 +130,7 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
this.readAdvice = readAdvice;
}
public static float calculateDefaultConfidenceInterval(int vectorDimension) {
@ -151,12 +163,13 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
bits,
compress,
rawVectorFormat.fieldsWriter(state),
flatVectorScorer);
flatVectorScorer,
readAdvice);
}
@Override
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99ScalarQuantizedVectorsReader(
state, rawVectorFormat.fieldsReader(state), flatVectorScorer);
state, rawVectorFormat.fieldsReader(state), flatVectorScorer, readAdvice);
}
}

View File

@ -62,14 +62,16 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
private final IndexInput quantizedVectorData;
private final FlatVectorsReader rawVectorsReader;
private final FieldInfos fieldInfos;
private final ReadAdvice readAdvice;
public Lucene99ScalarQuantizedVectorsReader(
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer)
SegmentReadState state, FlatVectorsReader rawVectorsReader, FlatVectorsScorer scorer, ReadAdvice readAdvice)
throws IOException {
super(scorer);
this.rawVectorsReader = rawVectorsReader;
this.fieldInfos = state.fieldInfos;
int versionMeta = -1;
this.readAdvice = readAdvice;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
@ -99,9 +101,7 @@ public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReade
versionMeta,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
// Quantized vectors are accessed randomly from their node ID stored in the HNSW
// graph.
state.context.withReadAdvice(ReadAdvice.RANDOM));
state.context.withReadAdvice(readAdvice));
success = true;
} finally {
if (success == false) {

View File

@ -54,6 +54,7 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.VectorUtil;
@ -101,13 +102,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final byte bits;
private final boolean compress;
private final int version;
private final ReadAdvice readAdvice;
private boolean finished;
public Lucene99ScalarQuantizedVectorsWriter(
SegmentWriteState state,
Float confidenceInterval,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
FlatVectorsScorer scorer,
ReadAdvice readAdvice)
throws IOException {
this(
state,
@ -116,7 +119,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
(byte) 7,
false,
rawVectorDelegate,
scorer);
scorer,
readAdvice);
if (confidenceInterval != null && confidenceInterval == 0) {
throw new IllegalArgumentException("confidenceInterval cannot be set to zero");
}
@ -128,7 +132,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
FlatVectorsScorer scorer,
ReadAdvice readAdvice)
throws IOException {
this(
state,
@ -137,7 +142,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
bits,
compress,
rawVectorDelegate,
scorer);
scorer,
readAdvice);
}
private Lucene99ScalarQuantizedVectorsWriter(
@ -147,7 +153,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
byte bits,
boolean compress,
FlatVectorsWriter rawVectorDelegate,
FlatVectorsScorer scorer)
FlatVectorsScorer scorer,
ReadAdvice readAdvice)
throws IOException {
super(scorer);
this.confidenceInterval = confidenceInterval;
@ -167,6 +174,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
this.rawVectorDelegate = rawVectorDelegate;
this.readAdvice = readAdvice;
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
@ -491,7 +499,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
segmentWriteState.directory.openInput(
tempQuantizedVectorData.getName(), segmentWriteState.context);
tempQuantizedVectorData.getName(), segmentWriteState.context.withReadAdvice(readAdvice));
quantizedVectorData.copyBytes(
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;

View File

@ -132,7 +132,6 @@ abstract class AbstractKnnVectorQuery extends Query {
if (scorer == null) {
return NO_RESULTS;
}
BitSet acceptDocs = createBitSet(scorer.iterator(), liveDocs, reader.maxDoc());
final int cost = acceptDocs.cardinality();
QueryTimeout queryTimeout = timeLimitingKnnCollectorManager.getQueryTimeout();

View File

@ -41,6 +41,7 @@ import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.ReadAdvice;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.VectorUtil;
@ -64,7 +65,7 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
}
format =
new Lucene99ScalarQuantizedVectorsFormat(
confidenceInterval, bits, bits == 4 ? random().nextBoolean() : false);
confidenceInterval, bits, bits == 4 ? random().nextBoolean() : false, random().nextBoolean() ? ReadAdvice.RANDOM : ReadAdvice.SEQUENTIAL);
super.setUp();
}
@ -198,7 +199,7 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99ScalarQuantizedVectorsFormat(0.9f, (byte) 4, false);
return new Lucene99ScalarQuantizedVectorsFormat(0.9f, (byte) 4, false, ReadAdvice.RANDOM);
}
};
String expectedPattern =
@ -212,16 +213,16 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
public void testLimits() {
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(1.1f, 7, false));
() -> new Lucene99ScalarQuantizedVectorsFormat(1.1f, 7, false, ReadAdvice.RANDOM));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, -1, false));
() -> new Lucene99ScalarQuantizedVectorsFormat(null, -1, false, ReadAdvice.RANDOM));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 5, false));
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 5, false, ReadAdvice.RANDOM));
expectThrows(
IllegalArgumentException.class,
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 9, false));
() -> new Lucene99ScalarQuantizedVectorsFormat(null, 9, false, ReadAdvice.RANDOM));
}
@Override