From a47ba3369f402690ea9418b6f9c4c0cf367eab8d Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Fri, 10 Nov 2023 14:05:19 -0500 Subject: [PATCH] Refactoring HNSW to use a new internal FlatVectorFormat (#12729) Currently the HNSW codec does too many things, it not only indexes vectors, but stores them and determines how to store them given the vector type. This PR extracts out the vector storage into a new format `Lucene99FlatVectorsFormat` and adds new base class called `FlatVectorsFormat`. This allows for some additional helper functions that allow an indexing codec (like HNSW) take advantage of the flat formats. Additionally, this PR refactors the new `Lucene99ScalarQuantizedVectorsFormat` to be a `FlatVectorsFormat`. Now, `Lucene99HnswVectorsFormat` is constructed with a `Lucene99FlatVectorsFormat` and a new `Lucene99HnswScalarQuantizedVectorsFormat` that uses `Lucene99ScalarQuantizedVectorsFormat` --- lucene/CHANGES.txt | 4 + .../lucene92/OffHeapFloatVectorValues.java | 8 +- .../lucene94/OffHeapByteVectorValues.java | 8 +- .../lucene94/OffHeapFloatVectorValues.java | 8 +- lucene/core/src/java/module-info.java | 4 +- .../lucene/codecs/FlatFieldVectorsWriter.java | 43 ++ .../lucene/codecs/FlatVectorsFormat.java | 39 ++ .../lucene/codecs/FlatVectorsReader.java | 92 ++++ .../lucene/codecs/FlatVectorsWriter.java | 74 +++ .../lucene95/OffHeapByteVectorValues.java | 2 - .../lucene95/OffHeapFloatVectorValues.java | 2 - .../lucene99/Lucene99FlatVectorsFormat.java | 98 ++++ .../lucene99/Lucene99FlatVectorsReader.java | 333 ++++++++++++ .../lucene99/Lucene99FlatVectorsWriter.java | 508 +++++++++++++++++ ...ene99HnswScalarQuantizedVectorsFormat.java | 159 ++++++ .../lucene99/Lucene99HnswVectorsFormat.java | 87 +-- .../lucene99/Lucene99HnswVectorsReader.java | 268 ++------- .../lucene99/Lucene99HnswVectorsWriter.java | 512 ++---------------- .../Lucene99ScalarQuantizedVectorsFormat.java | 37 +- .../Lucene99ScalarQuantizedVectorsReader.java | 287 +++++++++- .../Lucene99ScalarQuantizedVectorsWriter.java | 422 +++++++++++---- .../OffHeapQuantizedByteVectorValues.java | 8 +- .../ScalarQuantizedRandomVectorScorer.java | 5 +- .../CloseableRandomVectorScorerSupplier.java | 5 +- .../util/hnsw/RandomAccessVectorValues.java | 11 + .../lucene/util/hnsw/RandomVectorScorer.java | 73 ++- .../util/hnsw/RandomVectorScorerSupplier.java | 16 +- .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + ...estLucene99HnswQuantizedVectorsFormat.java | 10 +- .../TestLucene99HnswVectorsFormat.java | 2 +- ...tLucene99ScalarQuantizedVectorsFormat.java | 79 --- .../org/apache/lucene/index/TestKnnGraph.java | 17 +- .../lucene/util/hnsw/TestNeighborArray.java | 22 +- 33 files changed, 2229 insertions(+), 1015 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java delete mode 100644 lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 5851311f357..4631e21819b 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -184,6 +184,10 @@ New Features * GITHUB#12660: HNSW graph now can be merged with multiple thread. Configurable in Lucene99HnswVectorsFormat. (Patrick Zhai) +* GITHUB#12729: Add new Lucene99FlatVectorsFormat for writing vectors in a flat format and refactor + Lucene99ScalarQuantizedVectorsFormat & Lucene99HnswVectorsFormat to reuse the flat formats. + Additionally, this allows flat formats to be pluggable independent of HNSW. (Ben Trent) + Improvements --------------------- * GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 9ca66cd47a8..267ce15bf8b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -80,8 +80,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } } - abstract Bits getAcceptOrds(Bits acceptDocs); - static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { private int doc = -1; @@ -120,7 +118,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } } @@ -184,7 +182,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { return null; } @@ -256,7 +254,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return null; } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 82a635e9c46..d80f38ccbf4 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -89,8 +89,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } } - abstract Bits getAcceptOrds(Bits acceptDocs); - static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { private int doc = -1; @@ -129,7 +127,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } } @@ -196,7 +194,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { return null; } @@ -268,7 +266,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return null; } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index bff4b4bf9a1..1035fe5dd72 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -86,8 +86,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } } - abstract Bits getAcceptOrds(Bits acceptDocs); - static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { private int doc = -1; @@ -126,7 +124,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } } @@ -193,7 +191,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { return null; } @@ -265,7 +263,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return null; } } diff --git a/lucene/core/src/java/module-info.java b/lucene/core/src/java/module-info.java index 096c965dc98..2bddbdd81a9 100644 --- a/lucene/core/src/java/module-info.java +++ b/lucene/core/src/java/module-info.java @@ -16,7 +16,6 @@ */ import org.apache.lucene.codecs.lucene99.Lucene99Codec; -import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; /** Lucene Core. */ @SuppressWarnings("module") // the test framework is compiled after the core... @@ -70,7 +69,8 @@ module org.apache.lucene.core { provides org.apache.lucene.codecs.DocValuesFormat with org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; provides org.apache.lucene.codecs.KnnVectorsFormat with - Lucene99HnswVectorsFormat; + org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat, + org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; provides org.apache.lucene.codecs.PostingsFormat with org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat; provides org.apache.lucene.index.SortFieldProvider with diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java new file mode 100644 index 00000000000..679b3d3af2e --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/FlatFieldVectorsWriter.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs; + +/** + * Vectors' writer for a field + * + * @param an array type; the type of vectors to be written + * @lucene.experimental + */ +public abstract class FlatFieldVectorsWriter extends KnnFieldVectorsWriter { + + /** + * The delegate to write to, can be null When non-null, all vectors seen should be written to the + * delegate along with being written to the flat vectors. + */ + protected final KnnFieldVectorsWriter indexingDelegate; + + /** + * Sole constructor that expects some indexingDelegate. All vectors seen should be written to the + * delegate along with being written to the flat vectors. + * + * @param indexingDelegate the delegate to write to, can be null + */ + protected FlatFieldVectorsWriter(KnnFieldVectorsWriter indexingDelegate) { + this.indexingDelegate = indexingDelegate; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java new file mode 100644 index 00000000000..3bfb19ced57 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsFormat.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs; + +import java.io.IOException; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * Encodes/decodes per-document vectors + * + * @lucene.experimental + */ +public abstract class FlatVectorsFormat { + + /** Sole constructor */ + protected FlatVectorsFormat() {} + + /** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */ + public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException; + + /** Returns a {@link KnnVectorsReader} to read the vectors from the index. */ + public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java new file mode 100644 index 00000000000..eca0fc97209 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsReader.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs; + +import java.io.Closeable; +import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +/** + * Reads vectors from an index. When searching this reader, it iterates every vector in the index + * and scores them + * + *

This class is useful when: + * + *

    + *
  • the number of vectors is small + *
  • when used along side some additional indexing structure that can be used to better search + * the vectors (like HNSW). + *
+ * + * @lucene.experimental + */ +public abstract class FlatVectorsReader implements Closeable, Accountable { + + /** Sole constructor */ + protected FlatVectorsReader() {} + + /** + * Returns a {@link RandomVectorScorer} for the given field and target vector. + * + * @param field the field to search + * @param target the target vector + * @return a {@link RandomVectorScorer} for the given field and target vector. + * @throws IOException if an I/O error occurs when reading from the index. + */ + public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] target) + throws IOException; + + /** + * Returns a {@link RandomVectorScorer} for the given field and target vector. + * + * @param field the field to search + * @param target the target vector + * @return a {@link RandomVectorScorer} for the given field and target vector. + * @throws IOException if an I/O error occurs when reading from the index. + */ + public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target) + throws IOException; + + /** + * Checks consistency of this reader. + * + *

Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value + * against large data files. + * + * @lucene.internal + */ + public abstract void checkIntegrity() throws IOException; + + /** + * Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if + * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is + * never {@code null}. + */ + public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException; + + /** + * Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if + * the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is + * never {@code null}. + */ + public abstract ByteVectorValues getByteVectorValues(String field) throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java new file mode 100644 index 00000000000..07cca250c0f --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/FlatVectorsWriter.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs; + +import java.io.Closeable; +import java.io.IOException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; + +/** + * Vectors' writer for a field that allows additional indexing logic to be implemented by the caller + * + * @lucene.experimental + */ +public abstract class FlatVectorsWriter implements Accountable, Closeable { + + /** Sole constructor */ + protected FlatVectorsWriter() {} + + /** + * Add a new field for indexing, allowing the user to provide a writer that the flat vectors + * writer can delegate to if additional indexing logic is required. + * + * @param fieldInfo fieldInfo of the field to add + * @param indexWriter the writer to delegate to, can be null + * @return a writer for the field + * @throws IOException if an I/O error occurs when adding the field + */ + public abstract FlatFieldVectorsWriter addField( + FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException; + + /** + * Write the field for merging, providing a scorer over the newly merged flat vectors. This way + * any additional merging logic can be implemented by the user of this class. + * + * @param fieldInfo fieldInfo of the field to merge + * @param mergeState mergeState of the segments to merge + * @return a scorer over the newly merged flat vectors, which should be closed as it holds a + * temporary file handle to read over the newly merged vectors + * @throws IOException if an I/O error occurs when merging + */ + public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException; + + /** Write field for merging */ + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState)); + } + + /** Called once at the end before close */ + public abstract void finish() throws IOException; + + /** Flush all buffered data on disk * */ + public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException; +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 37ef4727394..c11ed70f0b8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -94,8 +94,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues } } - public abstract Bits getAcceptOrds(Bits acceptDocs); - /** * Dense vector values that are stored off-heap. This is the most common case when every doc has a * vector. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 44ea057fc03..93cca6262d1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -88,8 +88,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues } } - public abstract Bits getAcceptOrds(Bits acceptDocs); - /** * Dense vector values that are stored off-heap. This is the most common case when every doc has a * vector. diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java new file mode 100644 index 00000000000..39fddb33653 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsFormat.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import org.apache.lucene.codecs.FlatVectorsFormat; +import org.apache.lucene.codecs.FlatVectorsReader; +import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.codecs.lucene90.IndexedDISI; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexOutput; + +/** + * Lucene 9.9 flat vector format, which encodes numeric vector values + * + *

.vec (vector data) file

+ * + *

For each field: + * + *

    + *
  • Vector data ordered by field, document ordinal, and vector dimension. When the + * vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each + * sample is stored as an IEEE float in little-endian byte order. + *
  • DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}, + * note that only in sparse case + *
  • OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note + * that only in sparse case + *
+ * + *

.vemf (vector metadata) file

+ * + *

For each field: + * + *

    + *
  • [int32] field number + *
  • [int32] vector similarity function ordinal + *
  • [vlong] offset to this field's vectors in the .vec file + *
  • [vlong] length of this field's vectors, in bytes + *
  • [vint] dimension of this field's vectors + *
  • [int] the number of documents having values for this field + *
  • [int8] if equals to -1, dense – all documents have values for a field. If equals to + * 0, sparse – some documents missing values. + *
  • DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)} + *
  • OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note + * that only in sparse case + *
+ * + * @lucene.experimental + */ +public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat { + + static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData"; + static final String META_EXTENSION = "vemf"; + static final String VECTOR_DATA_EXTENSION = "vec"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16; + + /** Constructs a format */ + public Lucene99FlatVectorsFormat() { + super(); + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99FlatVectorsWriter(state); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99FlatVectorsReader(state); + } + + @Override + public String toString() { + return "Lucene99FlatVectorsFormat()"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java new file mode 100644 index 00000000000..89dbcd8ff6e --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsReader.java @@ -0,0 +1,333 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatVectorsReader; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.RandomVectorScorer; + +/** + * Reads vectors from the index segments. + * + * @lucene.experimental + */ +public final class Lucene99FlatVectorsReader extends FlatVectorsReader { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class); + + private final Map fields = new HashMap<>(); + private final IndexInput vectorData; + + Lucene99FlatVectorsReader(SegmentReadState state) throws IOException { + int versionMeta = readMetadata(state); + boolean success = false; + try { + vectorData = + openDataInput( + state, + versionMeta, + Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION, + Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + private int readMetadata(SegmentReadState state) throws IOException { + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION); + int versionMeta = -1; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene99FlatVectorsFormat.META_CODEC_NAME, + Lucene99FlatVectorsFormat.VERSION_START, + Lucene99FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(meta, priorE); + } + } + return versionMeta; + } + + private static IndexInput openDataInput( + SegmentReadState state, int versionMeta, String fileExtension, String codecName) + throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, state.context); + boolean success = false; + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene99FlatVectorsFormat.VERSION_START, + Lucene99FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { + int dimension = info.getVectorDimension(); + if (dimension != fieldEntry.dimension) { + throw new IllegalStateException( + "Inconsistent vector dimension for field=\"" + + info.name + + "\"; " + + dimension + + " != " + + fieldEntry.dimension); + } + + int byteSize = + switch (info.getVectorEncoding()) { + case BYTE -> Byte.BYTES; + case FLOAT32 -> Float.BYTES; + }; + long vectorBytes = Math.multiplyExact((long) dimension, byteSize); + long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size); + if (numBytes != fieldEntry.vectorDataLength) { + throw new IllegalStateException( + "Vector data length " + + fieldEntry.vectorDataLength + + " not matching size=" + + fieldEntry.size + + " * dim=" + + dimension + + " * byteSize=" + + byteSize + + " = " + + numBytes); + } + } + + private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { + int similarityFunctionId = input.readInt(); + if (similarityFunctionId < 0 + || similarityFunctionId >= VectorSimilarityFunction.values().length) { + throw new CorruptIndexException( + "Invalid similarity function id: " + similarityFunctionId, input); + } + return VectorSimilarityFunction.values()[similarityFunctionId]; + } + + private VectorEncoding readVectorEncoding(DataInput input) throws IOException { + int encodingId = input.readInt(); + if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { + throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); + } + return VectorEncoding.values()[encodingId]; + } + + private FieldEntry readField(IndexInput input) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + return new FieldEntry(input, vectorEncoding, similarityFunction); + } + + @Override + public long ramBytesUsed() { + return Lucene99FlatVectorsReader.SHALLOW_SIZE + + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + } + + @Override + public void checkIntegrity() throws IOException { + CodecUtil.checksumEntireFile(vectorData); + } + + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + return OffHeapFloatVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { + throw new IllegalArgumentException( + "field=\"" + + field + + "\" is encoded as: " + + fieldEntry.vectorEncoding + + " expected: " + + VectorEncoding.FLOAT32); + } + return OffHeapByteVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } + return RandomVectorScorer.createFloats( + OffHeapFloatVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData), + fieldEntry.similarityFunction, + target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { + return null; + } + return RandomVectorScorer.createBytes( + OffHeapByteVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.vectorEncoding, + fieldEntry.dimension, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + vectorData), + fieldEntry.similarityFunction, + target); + } + + @Override + public void close() throws IOException { + IOUtils.close(vectorData); + } + + private static class FieldEntry implements Accountable { + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class); + final VectorSimilarityFunction similarityFunction; + final VectorEncoding vectorEncoding; + final int dimension; + final long vectorDataOffset; + final long vectorDataLength; + final int size; + final OrdToDocDISIReaderConfiguration ordToDoc; + + FieldEntry( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + this.similarityFunction = similarityFunction; + this.vectorEncoding = vectorEncoding; + vectorDataOffset = input.readVLong(); + vectorDataLength = input.readVLong(); + dimension = input.readVInt(); + size = input.readInt(); + ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java new file mode 100644 index 00000000000..e386fe67ae3 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.Closeable; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.DocsWithFieldSet; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; +import org.apache.lucene.util.hnsw.RandomVectorScorer; +import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; + +/** + * Writes vector values to index segments. + * + * @lucene.experimental + */ +public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter { + + private static final long SHALLLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class); + + private final SegmentWriteState segmentWriteState; + private final IndexOutput meta, vectorData; + + private final List> fields = new ArrayList<>(); + private boolean finished; + + Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException { + segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION); + + String vectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION); + + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + vectorData = state.directory.createOutput(vectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene99FlatVectorsFormat.META_CODEC_NAME, + Lucene99FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + vectorData, + Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene99FlatVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public FlatFieldVectorsWriter addField( + FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException { + FieldWriter newField = FieldWriter.create(fieldInfo, indexWriter); + fields.add(newField); + return newField; + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + for (FieldWriter field : fields) { + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + } + } + + @Override + public void finish() throws IOException { + if (finished) { + throw new IllegalStateException("already finished"); + } + finished = true; + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } + if (vectorData != null) { + CodecUtil.writeFooter(vectorData); + } + } + + @Override + public long ramBytesUsed() { + long total = SHALLLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; + } + + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectors(fieldData); + case FLOAT32 -> writeFloat32Vectors(fieldData); + } + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta( + fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField); + } + + private void writeFloat32Vectors(FieldWriter fieldData) throws IOException { + final ByteBuffer buffer = + ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (Object v : fieldData.vectors) { + buffer.asFloatBuffer().put((float[]) v); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + } + + private void writeByteVectors(FieldWriter fieldData) throws IOException { + for (Object v : fieldData.vectors) { + byte[] vector = (byte[]) v; + vectorData.writeBytes(vector, vector.length); + } + } + + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) + throws IOException { + final int[] docIdOffsets = new int[sortMap.size()]; + int offset = 1; // 0 means no vector for this (field, document) + DocIdSetIterator iterator = fieldData.docsWithField.iterator(); + for (int docID = iterator.nextDoc(); + docID != DocIdSetIterator.NO_MORE_DOCS; + docID = iterator.nextDoc()) { + int newDocID = sortMap.oldToNew(docID); + docIdOffsets[newDocID] = offset++; + } + DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); + final int[] ordMap = new int[offset - 1]; // new ord to old ord + int ord = 0; + int doc = 0; + for (int docIdOffset : docIdOffsets) { + if (docIdOffset != 0) { + ordMap[ord] = docIdOffset - 1; + newDocsWithField.add(doc); + ord++; + } + doc++; + } + + // write vector values + long vectorDataOffset = + switch (fieldData.fieldInfo.getVectorEncoding()) { + case BYTE -> writeSortedByteVectors(fieldData, ordMap); + case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + + writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField); + } + + private long writeSortedFloat32Vectors(FieldWriter fieldData, int[] ordMap) + throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + final ByteBuffer buffer = + ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int ordinal : ordMap) { + float[] vector = (float[]) fieldData.vectors.get(ordinal); + buffer.asFloatBuffer().put(vector); + vectorData.writeBytes(buffer.array(), buffer.array().length); + } + return vectorDataOffset; + } + + private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + for (int ordinal : ordMap) { + byte[] vector = (byte[]) fieldData.vectors.get(ordinal); + vectorData.writeBytes(vector, vector.length); + } + return vectorDataOffset; + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + // No need to use temporary file as we don't have to re-open for reading + DocsWithFieldSet docsWithField = + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData( + vectorData, + KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); + case FLOAT32 -> writeVectorData( + vectorData, + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + }; + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + docsWithField); + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); + IndexOutput tempVectorData = + segmentWriteState.directory.createTempOutput( + vectorData.getName(), "temp", segmentWriteState.context); + IndexInput vectorDataInput = null; + boolean success = false; + try { + // write the vector data to a temporary file + DocsWithFieldSet docsWithField = + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> writeByteVectorData( + tempVectorData, + KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); + case FLOAT32 -> writeVectorData( + tempVectorData, + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + }; + CodecUtil.writeFooter(tempVectorData); + IOUtils.close(tempVectorData); + + // copy the temporary file vectors to the actual data file + vectorDataInput = + segmentWriteState.directory.openInput( + tempVectorData.getName(), segmentWriteState.context); + vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); + CodecUtil.retrieveChecksum(vectorDataInput); + long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + docsWithField); + success = true; + final IndexInput finalVectorDataInput = vectorDataInput; + final RandomVectorScorerSupplier randomVectorScorerSupplier = + switch (fieldInfo.getVectorEncoding()) { + case BYTE -> RandomVectorScorerSupplier.createBytes( + new OffHeapByteVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * Byte.BYTES), + fieldInfo.getVectorSimilarityFunction()); + case FLOAT32 -> RandomVectorScorerSupplier.createFloats( + new OffHeapFloatVectorValues.DenseOffHeapVectorValues( + fieldInfo.getVectorDimension(), + docsWithField.cardinality(), + finalVectorDataInput, + fieldInfo.getVectorDimension() * Float.BYTES), + fieldInfo.getVectorSimilarityFunction()); + }; + return new FlatCloseableRandomVectorScorerSupplier( + () -> { + IOUtils.close(finalVectorDataInput); + segmentWriteState.directory.deleteFile(tempVectorData.getName()); + }, + docsWithField.cardinality(), + randomVectorScorerSupplier); + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData); + IOUtils.deleteFilesIgnoringExceptions( + segmentWriteState.directory, tempVectorData.getName()); + } + } + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + + // write docIDs + int count = docsWithField.cardinality(); + meta.writeInt(count); + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); + } + + /** + * Writes the byte vector values to the output and returns a set of documents that contains + * vectors. + */ + private static DocsWithFieldSet writeByteVectorData( + IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + for (int docV = byteVectorValues.nextDoc(); + docV != NO_MORE_DOCS; + docV = byteVectorValues.nextDoc()) { + // write vector + byte[] binaryValue = byteVectorValues.vectorValue(); + assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; + output.writeBytes(binaryValue, binaryValue.length); + docsWithField.add(docV); + } + return docsWithField; + } + + /** + * Writes the vector values to the output and returns a set of documents that contains vectors. + */ + private static DocsWithFieldSet writeVectorData( + IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { + DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + ByteBuffer buffer = + ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) + .order(ByteOrder.LITTLE_ENDIAN); + for (int docV = floatVectorValues.nextDoc(); + docV != NO_MORE_DOCS; + docV = floatVectorValues.nextDoc()) { + // write vector + float[] value = floatVectorValues.vectorValue(); + buffer.asFloatBuffer().put(value); + output.writeBytes(buffer.array(), buffer.limit()); + docsWithField.add(docV); + } + return docsWithField; + } + + @Override + public void close() throws IOException { + IOUtils.close(meta, vectorData); + } + + private abstract static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; + private final int dim; + private final DocsWithFieldSet docsWithField; + private final List vectors; + + private int lastDocID = -1; + + @SuppressWarnings("unchecked") + static FieldWriter create(FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) { + int dim = fieldInfo.getVectorDimension(); + return switch (fieldInfo.getVectorEncoding()) { + case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>( + fieldInfo, (KnnFieldVectorsWriter) indexWriter) { + @Override + public byte[] copyValue(byte[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>( + fieldInfo, (KnnFieldVectorsWriter) indexWriter) { + @Override + public float[] copyValue(float[] value) { + return ArrayUtil.copyOfSubArray(value, 0, dim); + } + }; + }; + } + + FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) { + super(indexWriter); + this.fieldInfo = fieldInfo; + this.dim = fieldInfo.getVectorDimension(); + this.docsWithField = new DocsWithFieldSet(); + vectors = new ArrayList<>(); + } + + @Override + public void addValue(int docID, T vectorValue) throws IOException { + if (docID == lastDocID) { + throw new IllegalArgumentException( + "VectorValuesField \"" + + fieldInfo.name + + "\" appears more than once in this document (only one value is allowed per field)"); + } + assert docID > lastDocID; + T copy = copyValue(vectorValue); + docsWithField.add(docID); + vectors.add(copy); + lastDocID = docID; + if (indexingDelegate != null) { + indexingDelegate.addValue(docID, copy); + } + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_RAM_BYTES_USED; + if (indexingDelegate != null) { + size += indexingDelegate.ramBytesUsed(); + } + if (vectors.size() == 0) return size; + return size + + docsWithField.ramBytesUsed() + + (long) vectors.size() + * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + + (long) vectors.size() + * fieldInfo.getVectorDimension() + * fieldInfo.getVectorEncoding().byteSize; + } + } + + static final class FlatCloseableRandomVectorScorerSupplier + implements CloseableRandomVectorScorerSupplier { + + private final RandomVectorScorerSupplier supplier; + private final Closeable onClose; + private final int numVectors; + + FlatCloseableRandomVectorScorerSupplier( + Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) { + this.onClose = onClose; + this.supplier = supplier; + this.numVectors = numVectors; + } + + @Override + public RandomVectorScorer scorer(int ord) throws IOException { + return supplier.scorer(ord); + } + + @Override + public RandomVectorScorerSupplier copy() throws IOException { + return supplier.copy(); + } + + @Override + public void close() throws IOException { + onClose.close(); + } + + @Override + public int totalVectorCount() { + return numVectors; + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java new file mode 100644 index 00000000000..23d607a1c77 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswScalarQuantizedVectorsFormat.java @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.lucene.codecs.lucene99; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH; +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN; + +import java.io.IOException; +import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.FlatVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.util.hnsw.HnswGraph; + +/** + * Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting + * the documents having values. The graph is used to power HNSW search. The format consists of two + * files, and uses {@link Lucene99ScalarQuantizedVectorsFormat} to store the actual vectors: For + * details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}. + * + * @lucene.experimental + */ +public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat { + + /** + * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to + * {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details. + */ + private final int maxConn; + + /** + * The number of candidate neighbors to track while searching the graph for each newly inserted + * node. Defaults to to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link + * HnswGraph} for details. + */ + private final int beamWidth; + + /** The format for storing, reading, merging vectors on disk */ + private final FlatVectorsFormat flatVectorsFormat; + + private final int numMergeWorkers; + private final ExecutorService mergeExec; + + /** Constructs a format using default graph construction parameters */ + public Lucene99HnswScalarQuantizedVectorsFormat() { + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null); + } + + /** + * Constructs a format using the given graph construction parameters. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + */ + public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) { + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null); + } + + /** + * Constructs a format using the given graph construction parameters and scalar quantization. + * + * @param maxConn the maximum number of connections to a node in the HNSW graph + * @param beamWidth the size of the queue maintained during graph construction. + * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If + * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec + * @param configuredQuantile the quantile for scalar quantizing the vectors, when `null` it is + * calculated based on the vector field dimensions. + * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are + * generated by this format to do the merge + */ + public Lucene99HnswScalarQuantizedVectorsFormat( + int maxConn, + int beamWidth, + int numMergeWorkers, + Float configuredQuantile, + ExecutorService mergeExec) { + super("Lucene99HnswScalarQuantizedVectorsFormat"); + if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { + throw new IllegalArgumentException( + "maxConn must be positive and less than or equal to" + + MAXIMUM_MAX_CONN + + "; maxConn=" + + maxConn); + } + if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) { + throw new IllegalArgumentException( + "beamWidth must be positive and less than or equal to" + + MAXIMUM_BEAM_WIDTH + + "; beamWidth=" + + beamWidth); + } + this.maxConn = maxConn; + this.beamWidth = beamWidth; + if (numMergeWorkers > 1 && mergeExec == null) { + throw new IllegalArgumentException( + "No executor service passed in when " + numMergeWorkers + " merge workers are requested"); + } + if (numMergeWorkers == 1 && mergeExec != null) { + throw new IllegalArgumentException( + "No executor service is needed as we'll use single thread to merge"); + } + this.numMergeWorkers = numMergeWorkers; + this.mergeExec = mergeExec; + this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(configuredQuantile); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99HnswVectorsWriter( + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=" + + maxConn + + ", beamWidth=" + + beamWidth + + ", flatVectorFormat=" + + flatVectorsFormat + + ")"; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java index 038b75e4c48..85d65df55b9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsFormat.java @@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; import java.util.concurrent.ExecutorService; +import org.apache.lucene.codecs.FlatVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -30,23 +31,9 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.hnsw.HnswGraph; /** - * Lucene 9.9 vector format, which encodes numeric vector values and an optional associated graph - * connecting the documents having values. The graph is used to power HNSW search. The format - * consists of three files, with an optional fourth file: - * - *

.vec (vector data) file

- * - *

For each field: - * - *

    - *
  • Vector data ordered by field, document ordinal, and vector dimension. When the - * vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each - * sample is stored as an IEEE float in little-endian byte order. - *
  • DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}, - * note that only in sparse case - *
  • OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note - * that only in sparse case - *
+ * Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting + * the documents having values. The graph is used to power HNSW search. The format consists of two + * files, and requires a {@link FlatVectorsFormat} to store the actual vectors: * *

.vex (vector index)

* @@ -74,14 +61,6 @@ import org.apache.lucene.util.hnsw.HnswGraph; *
    *
  • [int32] field number *
  • [int32] vector similarity function ordinal - *
  • [byte] if equals to 1 indicates if the field is for quantized vectors - *
  • [int32] if quantized: the configured quantile float int bits. - *
  • [int32] if quantized: the calculated lower quantile float int32 bits. - *
  • [int32] if quantized: the calculated upper quantile float int32 bits. - *
  • [vlong] if quantized: offset to this field's vectors in the .veq file - *
  • [vlong] if quantized: length of this field's vectors, in bytes in the .veq file - *
  • [vlong] offset to this field's vectors in the .vec file - *
  • [vlong] length of this field's vectors, in bytes *
  • [vlong] offset to this field's index in the .vex file *
  • [vlong] length of this field's index data, in bytes *
  • [vint] dimension of this field's vectors @@ -101,29 +80,13 @@ import org.apache.lucene.util.hnsw.HnswGraph; *
* * - *

.veq (quantized vector data) file

- * - *

For each field: - * - *

    - *
  • Vector data ordered by field, document ordinal, and vector dimension. Each vector dimension - * is stored as a single byte and every vector has a single float32 value for scoring - * corrections. - *
  • DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}, - * note that only in sparse case - *
  • OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note - * that only in sparse case - *
- * * @lucene.experimental */ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta"; - static final String VECTOR_DATA_CODEC_NAME = "Lucene99HnswVectorsFormatData"; static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex"; static final String META_EXTENSION = "vem"; - static final String VECTOR_DATA_EXTENSION = "vec"; static final String VECTOR_INDEX_EXTENSION = "vex"; public static final int VERSION_START = 0; @@ -135,7 +98,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { *

NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large * numbers here will use an inordinate amount of heap */ - private static final int MAXIMUM_MAX_CONN = 512; + static final int MAXIMUM_MAX_CONN = 512; /** Default number of maximum connections per node */ public static final int DEFAULT_MAX_CONN = 16; @@ -145,7 +108,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { * maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 = * 3200` */ - private static final int MAXIMUM_BEAM_WIDTH = 3200; + static final int MAXIMUM_BEAM_WIDTH = 3200; /** * Default number of the size of the queue maintained while searching during a graph construction. @@ -170,20 +133,15 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { */ private final int beamWidth; - /** Should this codec scalar quantize float32 vectors and use this format */ - private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat; + /** The format for storing, reading, merging vectors on disk */ + private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat(); private final int numMergeWorkers; private final ExecutorService mergeExec; /** Constructs a format using default graph construction parameters */ public Lucene99HnswVectorsFormat() { - this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null); - } - - public Lucene99HnswVectorsFormat( - int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) { - this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null); + this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null); } /** @@ -193,7 +151,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { * @param beamWidth the size of the queue maintained during graph construction. */ public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) { - this(maxConn, beamWidth, null); + this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null); } /** @@ -201,18 +159,13 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { * * @param maxConn the maximum number of connections to a node in the HNSW graph * @param beamWidth the size of the queue maintained during graph construction. - * @param scalarQuantize the scalar quantization format * @param numMergeWorkers number of workers (threads) that will be used when doing merge. If * larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec * @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are * generated by this format to do the merge */ public Lucene99HnswVectorsFormat( - int maxConn, - int beamWidth, - Lucene99ScalarQuantizedVectorsFormat scalarQuantize, - int numMergeWorkers, - ExecutorService mergeExec) { + int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) { super("Lucene99HnswVectorsFormat"); if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) { throw new IllegalArgumentException( @@ -228,6 +181,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { + "; beamWidth=" + beamWidth); } + this.maxConn = maxConn; + this.beamWidth = beamWidth; if (numMergeWorkers > 1 && mergeExec == null) { throw new IllegalArgumentException( "No executor service passed in when " + numMergeWorkers + " merge workers are requested"); @@ -236,9 +191,6 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { throw new IllegalArgumentException( "No executor service is needed as we'll use single thread to merge"); } - this.maxConn = maxConn; - this.beamWidth = beamWidth; - this.scalarQuantizedVectorsFormat = scalarQuantize; this.numMergeWorkers = numMergeWorkers; this.mergeExec = mergeExec; } @@ -246,12 +198,17 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { return new Lucene99HnswVectorsWriter( - state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec); + state, + maxConn, + beamWidth, + flatVectorsFormat.fieldsWriter(state), + numMergeWorkers, + mergeExec); } @Override public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new Lucene99HnswVectorsReader(state); + return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state)); } @Override @@ -265,8 +222,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat { + maxConn + ", beamWidth=" + beamWidth - + ", quantizer=" - + (scalarQuantizedVectorsFormat == null ? "none" : scalarQuantizedVectorsFormat.toString()) + + ", flatVectorFormat=" + + flatVectorsFormat + ")"; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 2e81f371f05..47f1b726527 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -24,11 +24,9 @@ import java.util.Arrays; import java.util.HashMap; import java.util.Map; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatVectorsReader; import org.apache.lucene.codecs.HnswGraphProvider; import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; -import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; -import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; @@ -67,49 +65,14 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader private final FieldInfos fieldInfos; private final Map fields = new HashMap<>(); - private final IndexInput vectorData; private final IndexInput vectorIndex; - private final IndexInput quantizedVectorData; - private final Lucene99ScalarQuantizedVectorsReader quantizedVectorsReader; + private final FlatVectorsReader flatVectorsReader; - Lucene99HnswVectorsReader(SegmentReadState state) throws IOException { - this.fieldInfos = state.fieldInfos; - int versionMeta = readMetadata(state); + Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader) + throws IOException { + this.flatVectorsReader = flatVectorsReader; boolean success = false; - try { - vectorData = - openDataInput( - state, - versionMeta, - Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION, - Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME); - vectorIndex = - openDataInput( - state, - versionMeta, - Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION, - Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME); - if (fields.values().stream().anyMatch(FieldEntry::hasQuantizedVectors)) { - quantizedVectorData = - openDataInput( - state, - versionMeta, - Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION, - Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME); - quantizedVectorsReader = new Lucene99ScalarQuantizedVectorsReader(quantizedVectorData); - } else { - quantizedVectorData = null; - quantizedVectorsReader = null; - } - success = true; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(this); - } - } - } - - private int readMetadata(SegmentReadState state) throws IOException { + this.fieldInfos = state.fieldInfos; String metaFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION); @@ -129,10 +92,30 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader } catch (Throwable exception) { priorE = exception; } finally { - CodecUtil.checkFooter(meta, priorE); + try { + CodecUtil.checkFooter(meta, priorE); + success = true; + } finally { + if (success == false) { + IOUtils.close(flatVectorsReader); + } + } + } + } + success = false; + try { + vectorIndex = + openDataInput( + state, + versionMeta, + Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION, + Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); } } - return versionMeta; } private static IndexInput openDataInput( @@ -194,31 +177,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader + " != " + fieldEntry.dimension); } - - int byteSize = - switch (info.getVectorEncoding()) { - case BYTE -> Byte.BYTES; - case FLOAT32 -> Float.BYTES; - }; - long vectorBytes = Math.multiplyExact((long) dimension, byteSize); - long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size); - if (numBytes != fieldEntry.vectorDataLength) { - throw new IllegalStateException( - "Vector data length " - + fieldEntry.vectorDataLength - + " not matching size=" - + fieldEntry.size - + " * dim=" - + dimension - + " * byteSize=" - + byteSize - + " = " - + numBytes); - } - if (fieldEntry.hasQuantizedVectors()) { - Lucene99ScalarQuantizedVectorsReader.validateFieldEntry( - info, fieldEntry.dimension, fieldEntry.size, fieldEntry.quantizedVectorDataLength); - } } private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { @@ -249,58 +207,24 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader public long ramBytesUsed() { return Lucene99HnswVectorsReader.SHALLOW_SIZE + RamUsageEstimator.sizeOfMap( - fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)) + + flatVectorsReader.ramBytesUsed(); } @Override public void checkIntegrity() throws IOException { - CodecUtil.checksumEntireFile(vectorData); + flatVectorsReader.checkIntegrity(); CodecUtil.checksumEntireFile(vectorIndex); - if (quantizedVectorsReader != null) { - quantizedVectorsReader.checkIntegrity(); - } } @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + VectorEncoding.FLOAT32); - } - return OffHeapFloatVectorValues.load( - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - fieldEntry.vectorDataLength, - vectorData); + return flatVectorsReader.getFloatVectorValues(field); } @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) { - throw new IllegalArgumentException( - "field=\"" - + field - + "\" is encoded as: " - + fieldEntry.vectorEncoding - + " expected: " - + VectorEncoding.FLOAT32); - } - return OffHeapByteVectorValues.load( - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - fieldEntry.vectorDataLength, - vectorData); + return flatVectorsReader.getByteVectorValues(field); } @Override @@ -313,42 +237,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { return; } - if (fieldEntry.hasQuantizedVectors()) { - OffHeapQuantizedByteVectorValues vectorValues = - quantizedVectorsReader.getQuantizedVectorValues( - fieldEntry.quantizedOrdToDoc, - fieldEntry.dimension, - fieldEntry.size, - fieldEntry.quantizedVectorDataOffset, - fieldEntry.quantizedVectorDataLength); - if (vectorValues == null) { - return; - } - RandomVectorScorer scorer = - new ScalarQuantizedRandomVectorScorer( - fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target); - HnswGraphSearcher.search( - scorer, - new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); - } else { - OffHeapFloatVectorValues vectorValues = - OffHeapFloatVectorValues.load( - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - fieldEntry.vectorDataLength, - vectorData); - RandomVectorScorer scorer = - RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target); - HnswGraphSearcher.search( - scorer, - new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), - getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); - } + RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); + HnswGraphSearcher.search( + scorer, + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc), + getGraph(fieldEntry), + scorer.getAcceptOrds(acceptDocs)); } @Override @@ -361,22 +255,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader || fieldEntry.vectorEncoding != VectorEncoding.BYTE) { return; } - - OffHeapByteVectorValues vectorValues = - OffHeapByteVectorValues.load( - fieldEntry.ordToDoc, - fieldEntry.vectorEncoding, - fieldEntry.dimension, - fieldEntry.vectorDataOffset, - fieldEntry.vectorDataLength, - vectorData); - RandomVectorScorer scorer = - RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target); + RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target); HnswGraphSearcher.search( scorer, - new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), + new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + scorer.getAcceptOrds(acceptDocs)); } @Override @@ -399,32 +283,23 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader @Override public void close() throws IOException { - IOUtils.close(vectorData, vectorIndex, quantizedVectorData); + IOUtils.close(flatVectorsReader, vectorIndex); } @Override - public OffHeapQuantizedByteVectorValues getQuantizedVectorValues(String field) - throws IOException { - FieldEntry fieldEntry = fields.get(field); - if (fieldEntry == null || fieldEntry.hasQuantizedVectors() == false) { - return null; + public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException { + if (flatVectorsReader instanceof QuantizedVectorsReader) { + return ((QuantizedVectorsReader) flatVectorsReader).getQuantizedVectorValues(field); } - assert quantizedVectorsReader != null && fieldEntry.quantizedOrdToDoc != null; - return quantizedVectorsReader.getQuantizedVectorValues( - fieldEntry.quantizedOrdToDoc, - fieldEntry.dimension, - fieldEntry.size, - fieldEntry.quantizedVectorDataOffset, - fieldEntry.quantizedVectorDataLength); + return null; } @Override - public ScalarQuantizer getQuantizationState(String fieldName) { - FieldEntry field = fields.get(fieldName); - if (field == null || field.hasQuantizedVectors() == false) { - return null; + public ScalarQuantizer getQuantizationState(String field) { + if (flatVectorsReader instanceof QuantizedVectorsReader) { + return ((QuantizedVectorsReader) flatVectorsReader).getQuantizationState(field); } - return field.scalarQuantizer; + return null; } static class FieldEntry implements Accountable { @@ -432,8 +307,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class); final VectorSimilarityFunction similarityFunction; final VectorEncoding vectorEncoding; - final long vectorDataOffset; - final long vectorDataLength; final long vectorIndexOffset; final long vectorIndexLength; final int M; @@ -446,13 +319,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader final long offsetsOffset; final int offsetsBlockShift; final long offsetsLength; - final OrdToDocDISIReaderConfiguration ordToDoc; - - final float configuredQuantile, lowerQuantile, upperQuantile; - final long quantizedVectorDataOffset, quantizedVectorDataLength; - final ScalarQuantizer scalarQuantizer; - final boolean isQuantized; - final OrdToDocDISIReaderConfiguration quantizedOrdToDoc; FieldEntry( IndexInput input, @@ -461,36 +327,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader throws IOException { this.similarityFunction = similarityFunction; this.vectorEncoding = vectorEncoding; - this.isQuantized = input.readByte() == 1; - // Has int8 quantization - if (isQuantized) { - configuredQuantile = Float.intBitsToFloat(input.readInt()); - lowerQuantile = Float.intBitsToFloat(input.readInt()); - upperQuantile = Float.intBitsToFloat(input.readInt()); - quantizedVectorDataOffset = input.readVLong(); - quantizedVectorDataLength = input.readVLong(); - scalarQuantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, configuredQuantile); - } else { - configuredQuantile = -1; - lowerQuantile = -1; - upperQuantile = -1; - quantizedVectorDataOffset = -1; - quantizedVectorDataLength = -1; - scalarQuantizer = null; - } - vectorDataOffset = input.readVLong(); - vectorDataLength = input.readVLong(); vectorIndexOffset = input.readVLong(); vectorIndexLength = input.readVLong(); dimension = input.readVInt(); size = input.readInt(); - if (isQuantized) { - quantizedOrdToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); - } else { - quantizedOrdToDoc = null; - } - ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); - // read nodes by level M = input.readVInt(); numLevels = input.readVInt(); @@ -526,16 +366,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader return size; } - boolean hasQuantizedVectors() { - return isQuantized; - } - @Override public long ramBytesUsed() { return SHALLOW_SIZE + Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum() - + RamUsageEstimator.sizeOf(ordToDoc) - + (quantizedOrdToDoc == null ? 0 : RamUsageEstimator.sizeOf(quantizedOrdToDoc)) + RamUsageEstimator.sizeOf(offsetsMeta); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index 691a7300c59..ec9909e9698 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -18,40 +18,27 @@ package org.apache.lucene.codecs.lucene99; import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.concurrent.ExecutorService; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatVectorsWriter; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsWriter; -import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; -import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; -import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.RamUsageEstimator; -import org.apache.lucene.util.ScalarQuantizer; import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier; import org.apache.lucene.util.hnsw.ConcurrentHnswMerger; import org.apache.lucene.util.hnsw.HnswGraph; @@ -62,7 +49,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; -import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -73,11 +59,13 @@ import org.apache.lucene.util.packed.DirectMonotonicWriter; */ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { + private static final long SHALLOW_RAM_BYTES_USED = + RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsWriter.class); private final SegmentWriteState segmentWriteState; - private final IndexOutput meta, vectorData, quantizedVectorData, vectorIndex; + private final IndexOutput meta, vectorIndex; private final int M; private final int beamWidth; - private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter; + private final FlatVectorsWriter flatVectorWriter; private final int numMergeWorkers; private final ExecutorService mergeExec; @@ -88,42 +76,30 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { SegmentWriteState state, int M, int beamWidth, - Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat, + FlatVectorsWriter flatVectorWriter, int numMergeWorkers, ExecutorService mergeExec) throws IOException { this.M = M; + this.flatVectorWriter = flatVectorWriter; this.beamWidth = beamWidth; this.numMergeWorkers = numMergeWorkers; this.mergeExec = mergeExec; segmentWriteState = state; + String metaFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION); - String vectorDataFileName = - IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION); - String indexDataFileName = IndexFileNames.segmentFileName( state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION); - final String quantizedVectorDataFileName = - quantizedVectorsFormat != null - ? IndexFileNames.segmentFileName( - state.segmentInfo.name, - state.segmentSuffix, - Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION) - : null; boolean success = false; try { meta = state.directory.createOutput(metaFileName, state.context); - vectorData = state.directory.createOutput(vectorDataFileName, state.context); vectorIndex = state.directory.createOutput(indexDataFileName, state.context); CodecUtil.writeIndexHeader( @@ -132,34 +108,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { Lucene99HnswVectorsFormat.VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); - CodecUtil.writeIndexHeader( - vectorData, - Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME, - Lucene99HnswVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); CodecUtil.writeIndexHeader( vectorIndex, Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME, Lucene99HnswVectorsFormat.VERSION_CURRENT, state.segmentInfo.getId(), state.segmentSuffix); - if (quantizedVectorDataFileName != null) { - quantizedVectorData = - state.directory.createOutput(quantizedVectorDataFileName, state.context); - CodecUtil.writeIndexHeader( - quantizedVectorData, - Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME, - Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - quantizedVectorsWriter = - new Lucene99ScalarQuantizedVectorsWriter( - quantizedVectorData, quantizedVectorsFormat.quantile); - } else { - quantizedVectorData = null; - quantizedVectorsWriter = null; - } success = true; } finally { if (success == false) { @@ -170,34 +124,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { @Override public KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedVectorFieldWriter = - null; - // Quantization only supports FLOAT32 for now - if (quantizedVectorsWriter != null - && fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - quantizedVectorFieldWriter = - quantizedVectorsWriter.addField(fieldInfo, segmentWriteState.infoStream); - } FieldWriter newField = - FieldWriter.create( - fieldInfo, M, beamWidth, segmentWriteState.infoStream, quantizedVectorFieldWriter); + FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream); fields.add(newField); - return newField; + return flatVectorWriter.addField(fieldInfo, newField); } @Override public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + flatVectorWriter.flush(maxDoc, sortMap); for (FieldWriter field : fields) { - long[] quantizedVectorOffsetAndLen = null; - if (field.quantizedWriter != null) { - assert quantizedVectorsWriter != null; - quantizedVectorOffsetAndLen = - quantizedVectorsWriter.flush(sortMap, field.quantizedWriter, field.docsWithField); - } if (sortMap == null) { - writeField(field, maxDoc, quantizedVectorOffsetAndLen); + writeField(field); } else { - writeSortingField(field, maxDoc, sortMap, quantizedVectorOffsetAndLen); + writeSortingField(field, sortMap); } } } @@ -208,40 +148,29 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { throw new IllegalStateException("already finished"); } finished = true; - if (quantizedVectorsWriter != null) { - quantizedVectorsWriter.finish(); - } + flatVectorWriter.finish(); if (meta != null) { // write end of fields marker meta.writeInt(-1); CodecUtil.writeFooter(meta); } - if (vectorData != null) { - CodecUtil.writeFooter(vectorData); + if (vectorIndex != null) { CodecUtil.writeFooter(vectorIndex); } } @Override public long ramBytesUsed() { - long total = 0; + long total = SHALLOW_RAM_BYTES_USED; + total += flatVectorWriter.ramBytesUsed(); for (FieldWriter field : fields) { total += field.ramBytesUsed(); } return total; } - private void writeField(FieldWriter fieldData, int maxDoc, long[] quantizedVecOffsetAndLen) - throws IOException { - // write vector values - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - switch (fieldData.fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectors(fieldData); - case FLOAT32 -> writeFloat32Vectors(fieldData); - } - long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; - + private void writeField(FieldWriter fieldData) throws IOException { // write graph long vectorIndexOffset = vectorIndex.getFilePointer(); OnHeapHnswGraph graph = fieldData.getGraph(); @@ -249,43 +178,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; writeMeta( - fieldData.isQuantized(), fieldData.fieldInfo, - maxDoc, - fieldData.getConfiguredQuantile(), - fieldData.getMinQuantile(), - fieldData.getMaxQuantile(), - quantizedVecOffsetAndLen, - vectorDataOffset, - vectorDataLength, vectorIndexOffset, vectorIndexLength, - fieldData.docsWithField, + fieldData.docsWithField.cardinality(), graph, graphLevelNodeOffsets); } - private void writeFloat32Vectors(FieldWriter fieldData) throws IOException { - final ByteBuffer buffer = - ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (Object v : fieldData.vectors) { - buffer.asFloatBuffer().put((float[]) v); - vectorData.writeBytes(buffer.array(), buffer.array().length); - } - } - - private void writeByteVectors(FieldWriter fieldData) throws IOException { - for (Object v : fieldData.vectors) { - byte[] vector = (byte[]) v; - vectorData.writeBytes(vector, vector.length); - } - } - - private void writeSortingField( - FieldWriter fieldData, - int maxDoc, - Sorter.DocMap sortMap, - long[] quantizedVectorOffsetAndLen) + private void writeSortingField(FieldWriter fieldData, Sorter.DocMap sortMap) throws IOException { final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) @@ -310,15 +211,6 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { } doc++; } - - // write vector values - long vectorDataOffset = - switch (fieldData.fieldInfo.getVectorEncoding()) { - case BYTE -> writeSortedByteVectors(fieldData, ordMap); - case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap); - }; - long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; - // write graph long vectorIndexOffset = vectorIndex.getFilePointer(); OnHeapHnswGraph graph = fieldData.getGraph(); @@ -327,44 +219,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; writeMeta( - fieldData.isQuantized(), fieldData.fieldInfo, - maxDoc, - fieldData.getConfiguredQuantile(), - fieldData.getMinQuantile(), - fieldData.getMaxQuantile(), - quantizedVectorOffsetAndLen, - vectorDataOffset, - vectorDataLength, vectorIndexOffset, vectorIndexLength, - newDocsWithField, + fieldData.docsWithField.cardinality(), mockGraph, graphLevelNodeOffsets); } - private long writeSortedFloat32Vectors(FieldWriter fieldData, int[] ordMap) - throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - final ByteBuffer buffer = - ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int ordinal : ordMap) { - float[] vector = (float[]) fieldData.vectors.get(ordinal); - buffer.asFloatBuffer().put(vector); - vectorData.writeBytes(buffer.array(), buffer.array().length); - } - return vectorDataOffset; - } - - private long writeSortedByteVectors(FieldWriter fieldData, int[] ordMap) throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - for (int ordinal : ordMap) { - byte[] vector = (byte[]) fieldData.vectors.get(ordinal); - vectorData.writeBytes(vector, vector.length); - } - return vectorDataOffset; - } - /** * Reconstructs the graph given the old and new node ids. * @@ -475,116 +337,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { @Override public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES); - IndexOutput tempVectorData = null; - IndexInput vectorDataInput = null; - CloseableRandomVectorScorerSupplier scorerSupplier = null; + CloseableRandomVectorScorerSupplier scorerSupplier = + flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState); boolean success = false; try { - ScalarQuantizer scalarQuantizer = null; - long[] quantizedVectorDataOffsetAndLength = null; - // If we have configured quantization and are FLOAT32 - if (quantizedVectorsWriter != null - && fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - // We need the quantization parameters to write to the meta file - scalarQuantizer = quantizedVectorsWriter.mergeQuantiles(fieldInfo, mergeState); - if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { - segmentWriteState.infoStream.message( - QUANTIZED_VECTOR_COMPONENT, - "Merged quantiles field: " - + fieldInfo.name - + " newly merged quantile: " - + scalarQuantizer); - } - assert scalarQuantizer != null; - quantizedVectorDataOffsetAndLength = new long[2]; - quantizedVectorDataOffsetAndLength[0] = quantizedVectorData.alignFilePointer(Float.BYTES); - scorerSupplier = - quantizedVectorsWriter.mergeOneField( - segmentWriteState, fieldInfo, mergeState, scalarQuantizer); - quantizedVectorDataOffsetAndLength[1] = - quantizedVectorData.getFilePointer() - quantizedVectorDataOffsetAndLength[0]; - } - final DocsWithFieldSet docsWithField; - int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize; - - // If we extract vector storage, this could be cleaner. - // But for now, vector storage & index creation/storage live together. - if (scorerSupplier == null) { - tempVectorData = - segmentWriteState.directory.createTempOutput( - vectorData.getName(), "temp", segmentWriteState.context); - docsWithField = - switch (fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectorData( - tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); - case FLOAT32 -> writeVectorData( - tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); - }; - CodecUtil.writeFooter(tempVectorData); - IOUtils.close(tempVectorData); - // copy the temporary file vectors to the actual data file - vectorDataInput = - segmentWriteState.directory.openInput( - tempVectorData.getName(), segmentWriteState.context); - vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength()); - CodecUtil.retrieveChecksum(vectorDataInput); - final RandomVectorScorerSupplier innerScoreSupplier = - switch (fieldInfo.getVectorEncoding()) { - case BYTE -> RandomVectorScorerSupplier.createBytes( - new OffHeapByteVectorValues.DenseOffHeapVectorValues( - fieldInfo.getVectorDimension(), - docsWithField.cardinality(), - vectorDataInput, - byteSize), - fieldInfo.getVectorSimilarityFunction()); - case FLOAT32 -> RandomVectorScorerSupplier.createFloats( - new OffHeapFloatVectorValues.DenseOffHeapVectorValues( - fieldInfo.getVectorDimension(), - docsWithField.cardinality(), - vectorDataInput, - byteSize), - fieldInfo.getVectorSimilarityFunction()); - }; - final String tempFileName = tempVectorData.getName(); - final IndexInput finalVectorDataInput = vectorDataInput; - scorerSupplier = - new CloseableRandomVectorScorerSupplier() { - boolean closed = false; - - @Override - public RandomVectorScorer scorer(int ord) throws IOException { - return innerScoreSupplier.scorer(ord); - } - - @Override - public void close() throws IOException { - if (closed) { - return; - } - closed = true; - IOUtils.close(finalVectorDataInput); - segmentWriteState.directory.deleteFile(tempFileName); - } - - @Override - public RandomVectorScorerSupplier copy() throws IOException { - // here we just return the inner out since we only need to close this outside copy - return innerScoreSupplier.copy(); - } - }; - } else { - // No need to use temporary file as we don't have to re-open for reading - docsWithField = - switch (fieldInfo.getVectorEncoding()) { - case BYTE -> writeByteVectorData( - vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState)); - case FLOAT32 -> writeVectorData( - vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); - }; - } - - long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; long vectorIndexOffset = vectorIndex.getFilePointer(); // build the graph using the temporary vector data // we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction @@ -592,7 +348,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { // TODO: separate random access vector values from DocIdSetIterator? OnHeapHnswGraph graph = null; int[][] vectorIndexNodeOffsets = null; - if (docsWithField.cardinality() != 0) { + if (scorerSupplier.totalVectorCount() > 0) { // build graph HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier); for (int i = 0; i < mergeState.liveDocs.length; i++) { @@ -608,23 +364,17 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { } graph = merger.merge( - mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality()); + mergedVectorIterator, + segmentWriteState.infoStream, + scorerSupplier.totalVectorCount()); vectorIndexNodeOffsets = writeGraph(graph); } long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; writeMeta( - scalarQuantizer != null, fieldInfo, - segmentWriteState.segmentInfo.maxDoc(), - scalarQuantizer == null ? null : scalarQuantizer.getConfiguredQuantile(), - scalarQuantizer == null ? null : scalarQuantizer.getLowerQuantile(), - scalarQuantizer == null ? null : scalarQuantizer.getUpperQuantile(), - quantizedVectorDataOffsetAndLength, - vectorDataOffset, - vectorDataLength, vectorIndexOffset, vectorIndexLength, - docsWithField, + scorerSupplier.totalVectorCount(), graph, vectorIndexNodeOffsets); success = true; @@ -632,11 +382,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { if (success) { IOUtils.close(scorerSupplier); } else { - IOUtils.closeWhileHandlingException(scorerSupplier, vectorDataInput, tempVectorData); - if (tempVectorData != null) { - IOUtils.deleteFilesIgnoringExceptions( - segmentWriteState.directory, tempVectorData.getName()); - } + IOUtils.closeWhileHandlingException(scorerSupplier); } } } @@ -652,7 +398,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { int countOnLevel0 = graph.size(); int[][] offsets = new int[graph.numLevels()][]; for (int level = 0; level < graph.numLevels(); level++) { - int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level)); + int[] sortedNodes = NodesIterator.getSortedNodes(graph.getNodesOnLevel(level)); offsets[level] = new int[sortedNodes.length]; int nodeOffsetId = 0; for (int node : sortedNodes) { @@ -680,80 +426,21 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { return offsets; } - public static int[] getSortedNodes(NodesIterator nodesOnLevel) { - int[] sortedNodes = new int[nodesOnLevel.size()]; - for (int n = 0; nodesOnLevel.hasNext(); n++) { - sortedNodes[n] = nodesOnLevel.nextInt(); - } - Arrays.sort(sortedNodes); - return sortedNodes; - } - - private HnswGraphMerger createGraphMerger( - FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) { - if (mergeExec != null) { - return new ConcurrentHnswMerger( - fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers); - } - return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth); - } - private void writeMeta( - boolean isQuantized, FieldInfo field, - int maxDoc, - Float configuredQuantizationQuantile, - Float lowerQuantile, - Float upperQuantile, - long[] quantizedVectorDataOffsetAndLen, - long vectorDataOffset, - long vectorDataLength, long vectorIndexOffset, long vectorIndexLength, - DocsWithFieldSet docsWithField, + int count, HnswGraph graph, int[][] graphLevelNodeOffsets) throws IOException { meta.writeInt(field.number); meta.writeInt(field.getVectorEncoding().ordinal()); meta.writeInt(field.getVectorSimilarityFunction().ordinal()); - meta.writeByte(isQuantized ? (byte) 1 : (byte) 0); - if (isQuantized) { - assert lowerQuantile != null - && upperQuantile != null - && quantizedVectorDataOffsetAndLen != null; - assert quantizedVectorDataOffsetAndLen.length == 2; - meta.writeInt( - Float.floatToIntBits( - configuredQuantizationQuantile != null - ? configuredQuantizationQuantile - : calculateDefaultQuantile(field.getVectorDimension()))); - meta.writeInt(Float.floatToIntBits(lowerQuantile)); - meta.writeInt(Float.floatToIntBits(upperQuantile)); - meta.writeVLong(quantizedVectorDataOffsetAndLen[0]); - meta.writeVLong(quantizedVectorDataOffsetAndLen[1]); - } else { - assert configuredQuantizationQuantile == null - && lowerQuantile == null - && upperQuantile == null - && quantizedVectorDataOffsetAndLen == null; - } - meta.writeVLong(vectorDataOffset); - meta.writeVLong(vectorDataLength); meta.writeVLong(vectorIndexOffset); meta.writeVLong(vectorIndexLength); meta.writeVInt(field.getVectorDimension()); - - // write docIDs - int count = docsWithField.cardinality(); meta.writeInt(count); - if (isQuantized) { - OrdToDocDISIReaderConfiguration.writeStoredMeta( - DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField); - } - OrdToDocDISIReaderConfiguration.writeStoredMeta( - DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField); - meta.writeVInt(M); // write graph nodes on each level if (graph == null) { @@ -799,109 +486,47 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { } } - /** - * Writes the byte vector values to the output and returns a set of documents that contains - * vectors. - */ - private static DocsWithFieldSet writeByteVectorData( - IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { - DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { - // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); - assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; - output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(docV); + private HnswGraphMerger createGraphMerger( + FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) { + if (mergeExec != null) { + return new ConcurrentHnswMerger( + fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers); } - return docsWithField; - } - - /** - * Writes the vector values to the output and returns a set of documents that contains vectors. - */ - private static DocsWithFieldSet writeVectorData( - IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { - DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - ByteBuffer buffer = - ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) - .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { - // write vector - float[] value = floatVectorValues.vectorValue(); - buffer.asFloatBuffer().put(value); - output.writeBytes(buffer.array(), buffer.limit()); - docsWithField.add(docV); - } - return docsWithField; + return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth); } @Override public void close() throws IOException { - IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData); + IOUtils.close(meta, vectorIndex, flatVectorWriter); } - private abstract static class FieldWriter extends KnnFieldVectorsWriter { + private static class FieldWriter extends KnnFieldVectorsWriter { + + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class); + private final FieldInfo fieldInfo; - private final int dim; private final DocsWithFieldSet docsWithField; private final List vectors; private final HnswGraphBuilder hnswGraphBuilder; - private final Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter - quantizedWriter; - private int lastDocID = -1; private int node = 0; - static FieldWriter create( - FieldInfo fieldInfo, - int M, - int beamWidth, - InfoStream infoStream, - Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter writer) + static FieldWriter create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { - int dim = fieldInfo.getVectorDimension(); return switch (fieldInfo.getVectorEncoding()) { - case BYTE -> new FieldWriter(fieldInfo, M, beamWidth, infoStream, writer) { - @Override - public byte[] copyValue(byte[] value) { - return ArrayUtil.copyOfSubArray(value, 0, dim); - } - }; - case FLOAT32 -> new FieldWriter(fieldInfo, M, beamWidth, infoStream, writer) { - @Override - public float[] copyValue(float[] value) { - return ArrayUtil.copyOfSubArray(value, 0, dim); - } - }; + case BYTE -> new FieldWriter(fieldInfo, M, beamWidth, infoStream); + case FLOAT32 -> new FieldWriter(fieldInfo, M, beamWidth, infoStream); }; } @SuppressWarnings("unchecked") - FieldWriter( - FieldInfo fieldInfo, - int M, - int beamWidth, - InfoStream infoStream, - Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedWriter) + FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream) throws IOException { this.fieldInfo = fieldInfo; - this.dim = fieldInfo.getVectorDimension(); this.docsWithField = new DocsWithFieldSet(); - this.quantizedWriter = quantizedWriter; vectors = new ArrayList<>(); - if (quantizedWriter != null - && fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { - throw new IllegalArgumentException( - "Vector encoding [" - + VectorEncoding.FLOAT32 - + "] required for quantized vectors; provided=" - + fieldInfo.getVectorEncoding()); - } - RAVectorValues raVectors = new RAVectorValues<>(vectors, dim); + RAVectorValues raVectors = new RAVectorValues<>(vectors, fieldInfo.getVectorDimension()); RandomVectorScorerSupplier scorerSupplier = switch (fieldInfo.getVectorEncoding()) { case BYTE -> RandomVectorScorerSupplier.createBytes( @@ -925,20 +550,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { + "\" appears more than once in this document (only one value is allowed per field)"); } assert docID > lastDocID; - T copy = copyValue(vectorValue); - if (quantizedWriter != null) { - assert vectorValue instanceof float[]; - quantizedWriter.addValue((float[]) copy); - } + vectors.add(vectorValue); docsWithField.add(docID); - vectors.add(copy); hnswGraphBuilder.addGraphNode(node); node++; lastDocID = docID; } + @Override + public T copyValue(T vectorValue) { + throw new UnsupportedOperationException(); + } + OnHeapHnswGraph getGraph() { - if (vectors.size() > 0) { + if (node > 0) { return hnswGraphBuilder.getGraph(); } else { return null; @@ -947,32 +572,11 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter { @Override public long ramBytesUsed() { - if (vectors.size() == 0) return 0; - long quantizationSpace = quantizedWriter != null ? quantizedWriter.ramBytesUsed() : 0L; - return docsWithField.ramBytesUsed() + return SHALLOW_SIZE + + docsWithField.ramBytesUsed() + (long) vectors.size() * (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) - + (long) vectors.size() - * fieldInfo.getVectorDimension() - * fieldInfo.getVectorEncoding().byteSize - + hnswGraphBuilder.getGraph().ramBytesUsed() - + quantizationSpace; - } - - Float getConfiguredQuantile() { - return quantizedWriter == null ? null : quantizedWriter.getQuantile(); - } - - Float getMinQuantile() { - return quantizedWriter == null ? null : quantizedWriter.getMinQuantile(); - } - - Float getMaxQuantile() { - return quantizedWriter == null ? null : quantizedWriter.getMaxQuantile(); - } - - boolean isQuantized() { - return quantizedWriter != null; + + hnswGraphBuilder.getGraph().ramBytesUsed(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java index 424491617e5..f6550a220e3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsFormat.java @@ -17,20 +17,31 @@ package org.apache.lucene.codecs.lucene99; +import java.io.IOException; +import org.apache.lucene.codecs.FlatVectorsFormat; +import org.apache.lucene.codecs.FlatVectorsReader; +import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + /** * Format supporting vector quantization, storage, and retrieval * * @lucene.experimental */ -public final class Lucene99ScalarQuantizedVectorsFormat { +public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat { public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC"; static final String NAME = "Lucene99ScalarQuantizedVectorsFormat"; static final int VERSION_START = 0; static final int VERSION_CURRENT = VERSION_START; - static final String QUANTIZED_VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsData"; - static final String QUANTIZED_VECTOR_DATA_EXTENSION = "veq"; + static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta"; + static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData"; + static final String META_EXTENSION = "vemq"; + static final String VECTOR_DATA_EXTENSION = "veq"; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat(); /** The minimum quantile */ private static final float MINIMUM_QUANTILE = 0.9f; @@ -74,6 +85,24 @@ public final class Lucene99ScalarQuantizedVectorsFormat { @Override public String toString() { - return NAME + "(name=" + NAME + ", quantile=" + quantile + ")"; + return NAME + + "(name=" + + NAME + + ", quantile=" + + quantile + + ", rawVectorFormat=" + + rawVectorFormat + + ")"; + } + + @Override + public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new Lucene99ScalarQuantizedVectorsWriter( + state, quantile, rawVectorFormat.fieldsWriter(state)); + } + + @Override + public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state)); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 68a74ca492a..3984b960be3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -18,46 +18,128 @@ package org.apache.lucene.codecs.lucene99; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatVectorsReader; import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataInput; import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Accountable; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.RamUsageEstimator; +import org.apache.lucene.util.ScalarQuantizer; +import org.apache.lucene.util.hnsw.RandomVectorScorer; /** * Reads Scalar Quantized vectors from the index segments along with index data structures. * * @lucene.experimental */ -public final class Lucene99ScalarQuantizedVectorsReader { +public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader + implements QuantizedVectorsReader { + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class); + + private final Map fields = new HashMap<>(); private final IndexInput quantizedVectorData; + private final FlatVectorsReader rawVectorsReader; - Lucene99ScalarQuantizedVectorsReader(IndexInput quantizedVectorData) { - this.quantizedVectorData = quantizedVectorData; + Lucene99ScalarQuantizedVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + int versionMeta = -1; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION); + boolean success = false; + try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + meta, + Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene99ScalarQuantizedVectorsFormat.VERSION_START, + Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(meta, state.fieldInfos); + } catch (Throwable exception) { + priorE = exception; + } finally { + try { + CodecUtil.checkFooter(meta, priorE); + success = true; + } finally { + if (success == false) { + IOUtils.close(rawVectorsReader); + } + } + } + } + success = false; + this.rawVectorsReader = rawVectorsReader; + try { + quantizedVectorData = + openDataInput( + state, + versionMeta, + Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION, + Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } } - static void validateFieldEntry( - FieldInfo info, int fieldDimension, int size, long quantizedVectorDataLength) { + private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + FieldInfo info = infos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + FieldEntry fieldEntry = readField(meta); + validateFieldEntry(info, fieldEntry); + fields.put(info.name, fieldEntry); + } + } + + static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) { int dimension = info.getVectorDimension(); - if (dimension != fieldDimension) { + if (dimension != fieldEntry.dimension) { throw new IllegalStateException( "Inconsistent vector dimension for field=\"" + info.name + "\"; " + dimension + " != " - + fieldDimension); + + fieldEntry.dimension); } // int8 quantized and calculated stored offset. long quantizedVectorBytes = dimension + Float.BYTES; - long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, size); - if (numQuantizedVectorBytes != quantizedVectorDataLength) { + long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size); + if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) { throw new IllegalStateException( "Quantized vector data length " - + quantizedVectorDataLength + + fieldEntry.vectorDataLength + " not matching size=" - + size + + fieldEntry.size + " * (dim=" + dimension + " + 4)" @@ -66,23 +148,184 @@ public final class Lucene99ScalarQuantizedVectorsReader { } } + @Override public void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); CodecUtil.checksumEntireFile(quantizedVectorData); } - OffHeapQuantizedByteVectorValues getQuantizedVectorValues( - OrdToDocDISIReaderConfiguration configuration, - int dimension, - int size, - long quantizedVectorDataOffset, - long quantizedVectorDataLength) + @Override + public FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + private static IndexInput openDataInput( + SegmentReadState state, int versionMeta, String fileExtension, String codecName) throws IOException { + String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + IndexInput in = state.directory.openInput(fileName, state.context); + boolean success = false; + try { + int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + Lucene99ScalarQuantizedVectorsFormat.VERSION_START, + Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException { + FieldEntry fieldEntry = fields.get(field); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } + if (fieldEntry.scalarQuantizer == null) { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + OffHeapQuantizedByteVectorValues vectorValues = + OffHeapQuantizedByteVectorValues.load( + fieldEntry.ordToDoc, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, + quantizedVectorData); + return new ScalarQuantizedRandomVectorScorer( + fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target); + } + + @Override + public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException { + return rawVectorsReader.getRandomVectorScorer(field, target); + } + + @Override + public void close() throws IOException { + IOUtils.close(quantizedVectorData, rawVectorsReader); + } + + @Override + public long ramBytesUsed() { + long size = SHALLOW_SIZE; + size += + RamUsageEstimator.sizeOfMap( + fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class)); + size += rawVectorsReader.ramBytesUsed(); + return size; + } + + private FieldEntry readField(IndexInput input) throws IOException { + VectorEncoding vectorEncoding = readVectorEncoding(input); + VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + return new FieldEntry(input, vectorEncoding, similarityFunction); + } + + private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { + int similarityFunctionId = input.readInt(); + if (similarityFunctionId < 0 + || similarityFunctionId >= VectorSimilarityFunction.values().length) { + throw new CorruptIndexException( + "Invalid similarity function id: " + similarityFunctionId, input); + } + return VectorSimilarityFunction.values()[similarityFunctionId]; + } + + private VectorEncoding readVectorEncoding(DataInput input) throws IOException { + int encodingId = input.readInt(); + if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { + throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); + } + return VectorEncoding.values()[encodingId]; + } + + @Override + public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException { + FieldEntry fieldEntry = fields.get(fieldName); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } return OffHeapQuantizedByteVectorValues.load( - configuration, - dimension, - size, - quantizedVectorDataOffset, - quantizedVectorDataLength, + fieldEntry.ordToDoc, + fieldEntry.dimension, + fieldEntry.size, + fieldEntry.vectorDataOffset, + fieldEntry.vectorDataLength, quantizedVectorData); } + + @Override + public ScalarQuantizer getQuantizationState(String fieldName) { + FieldEntry fieldEntry = fields.get(fieldName); + if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) { + return null; + } + return fieldEntry.scalarQuantizer; + } + + private static class FieldEntry implements Accountable { + private static final long SHALLOW_SIZE = + RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class); + final VectorSimilarityFunction similarityFunction; + final VectorEncoding vectorEncoding; + final int dimension; + final long vectorDataOffset; + final long vectorDataLength; + final ScalarQuantizer scalarQuantizer; + final int size; + final OrdToDocDISIReaderConfiguration ordToDoc; + + FieldEntry( + IndexInput input, + VectorEncoding vectorEncoding, + VectorSimilarityFunction similarityFunction) + throws IOException { + this.similarityFunction = similarityFunction; + this.vectorEncoding = vectorEncoding; + vectorDataOffset = input.readVLong(); + vectorDataLength = input.readVLong(); + dimension = input.readVInt(); + size = input.readInt(); + if (size > 0) { + float configuredQuantile = Float.intBitsToFloat(input.readInt()); + float minQuantile = Float.intBitsToFloat(input.readInt()); + float maxQuantile = Float.intBitsToFloat(input.readInt()); + scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, configuredQuantile); + } else { + scalarQuantizer = null; + } + ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size); + } + + @Override + public long ramBytesUsed() { + return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 93bc6b0011a..e74217b9a8f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.lucene99; +import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; @@ -29,13 +30,18 @@ import java.nio.ByteOrder; import java.util.ArrayList; import java.util.List; import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.FlatVectorsWriter; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.index.DocIDMerger; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -44,7 +50,6 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.Accountable; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.RamUsageEstimator; @@ -59,9 +64,9 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; * * @lucene.experimental */ -public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { +public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter { - private static final long BASE_RAM_BYTES_USED = + private static final long SHALLOW_RAM_BYTES_USED = shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class); // Used for determining when merged quantiles shifted too far from individual segment quantiles. @@ -82,67 +87,214 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { // the quantization error) and the condition is sensitive enough to detect all adversarial cases, // such as merging clustered data. private static final float REQUANTIZATION_LIMIT = 0.2f; - private final IndexOutput quantizedVectorData; + private final SegmentWriteState segmentWriteState; + + private final List fields = new ArrayList<>(); + private final IndexOutput meta, quantizedVectorData; private final Float quantile; + private final FlatVectorsWriter rawVectorDelegate; private boolean finished; - Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) { - this.quantile = quantile; - this.quantizedVectorData = quantizedVectorData; - } - - QuantizationFieldVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) { - if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { - throw new IllegalArgumentException( - "Only float32 vector fields are supported for quantization"); - } - float quantile = - this.quantile == null - ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) - : this.quantile; - if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { - infoStream.message( - QUANTIZED_VECTOR_COMPONENT, - "quantizing field=" - + fieldInfo.name - + " dimension=" - + fieldInfo.getVectorDimension() - + " quantile=" - + quantile); - } - return QuantizationFieldVectorWriter.create(fieldInfo, quantile, infoStream); - } - - long[] flush( - Sorter.DocMap sortMap, QuantizationFieldVectorWriter field, DocsWithFieldSet docsWithField) + Lucene99ScalarQuantizedVectorsWriter( + SegmentWriteState state, Float quantile, FlatVectorsWriter rawVectorDelegate) throws IOException { - field.finish(); - return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField); + this.quantile = quantile; + segmentWriteState = state; + String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION); + + String quantizedVectorDataFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION); + this.rawVectorDelegate = rawVectorDelegate; + boolean success = false; + try { + meta = state.directory.createOutput(metaFileName, state.context); + quantizedVectorData = + state.directory.createOutput(quantizedVectorDataFileName, state.context); + + CodecUtil.writeIndexHeader( + meta, + Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME, + Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + CodecUtil.writeIndexHeader( + quantizedVectorData, + Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME, + Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } } - void finish() throws IOException { + @Override + public FlatFieldVectorsWriter addField( + FieldInfo fieldInfo, KnnFieldVectorsWriter indexWriter) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + float quantile = + this.quantile == null + ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + FieldWriter quantizedWriter = + new FieldWriter(quantile, fieldInfo, segmentWriteState.infoStream, indexWriter); + fields.add(quantizedWriter); + indexWriter = quantizedWriter; + } + return rawVectorDelegate.addField(fieldInfo, indexWriter); + } + + @Override + public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + // Since we know we will not be searching for additional indexing, we can just write the + // the vectors directly to the new segment. + // No need to use temporary file as we don't have to re-open for reading + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState); + MergedQuantizedVectorValues byteVectorValues = + MergedQuantizedVectorValues.mergeQuantizedByteVectorValues( + fieldInfo, mergeState, mergedQuantizationState); + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + DocsWithFieldSet docsWithField = + writeQuantizedVectorData(quantizedVectorData, byteVectorValues); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + float quantile = + this.quantile == null + ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + quantile, + mergedQuantizationState.getLowerQuantile(), + mergedQuantizationState.getUpperQuantile(), + docsWithField); + } + } + + @Override + public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex( + FieldInfo fieldInfo, MergeState mergeState) throws IOException { + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + // Simply merge the underlying delegate, which just copies the raw vector data to a new + // segment file + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState); + return mergeOneFieldToIndex( + segmentWriteState, fieldInfo, mergeState, mergedQuantizationState); + } + // We only merge the delegate, since the field type isn't float32, quantization wasn't + // supported, so bypass it. + return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState); + } + + @Override + public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter field : fields) { + field.finish(); + if (sortMap == null) { + writeField(field, maxDoc); + } else { + writeSortingField(field, maxDoc, sortMap); + } + } + } + + @Override + public void finish() throws IOException { if (finished) { throw new IllegalStateException("already finished"); } finished = true; + rawVectorDelegate.finish(); + if (meta != null) { + // write end of fields marker + meta.writeInt(-1); + CodecUtil.writeFooter(meta); + } if (quantizedVectorData != null) { CodecUtil.writeFooter(quantizedVectorData); } } - private long[] writeField(QuantizationFieldVectorWriter fieldData) throws IOException { - long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); - writeQuantizedVectors(fieldData); - long quantizedVectorDataLength = - quantizedVectorData.getFilePointer() - quantizedVectorDataOffset; - return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength}; + @Override + public long ramBytesUsed() { + long total = SHALLOW_RAM_BYTES_USED; + for (FieldWriter field : fields) { + total += field.ramBytesUsed(); + } + return total; } - private void writeQuantizedVectors(QuantizationFieldVectorWriter fieldData) throws IOException { + private void writeField(FieldWriter fieldData, int maxDoc) throws IOException { + // write vector values + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); + writeQuantizedVectors(fieldData); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; + + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + vectorDataLength, + quantile, + fieldData.minQuantile, + fieldData.maxQuantile, + fieldData.docsWithField); + } + + private void writeMeta( + FieldInfo field, + int maxDoc, + long vectorDataOffset, + long vectorDataLength, + Float configuredQuantizationQuantile, + Float lowerQuantile, + Float upperQuantile, + DocsWithFieldSet docsWithField) + throws IOException { + meta.writeInt(field.number); + meta.writeInt(field.getVectorEncoding().ordinal()); + meta.writeInt(field.getVectorSimilarityFunction().ordinal()); + meta.writeVLong(vectorDataOffset); + meta.writeVLong(vectorDataLength); + meta.writeVInt(field.getVectorDimension()); + int count = docsWithField.cardinality(); + meta.writeInt(count); + if (count > 0) { + assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile); + meta.writeInt( + Float.floatToIntBits( + configuredQuantizationQuantile != null + ? configuredQuantizationQuantile + : calculateDefaultQuantile(field.getVectorDimension()))); + meta.writeInt(Float.floatToIntBits(lowerQuantile)); + meta.writeInt(Float.floatToIntBits(upperQuantile)); + } + // write docIDs + OrdToDocDISIReaderConfiguration.writeStoredMeta( + DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField); + } + + private void writeQuantizedVectors(FieldWriter fieldData) throws IOException { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); - byte[] vector = new byte[fieldData.dim]; + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; + float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; for (float[] v : fieldData.floatVectors) { if (fieldData.normalize) { System.arraycopy(v, 0, copy, 0, copy.length); @@ -151,7 +303,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { } float offsetCorrection = - scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); quantizedVectorData.writeBytes(vector, vector.length); offsetBuffer.putFloat(offsetCorrection); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); @@ -159,14 +311,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { } } - private long[] writeSortingField( - QuantizationFieldVectorWriter fieldData, - Sorter.DocMap sortMap, - DocsWithFieldSet docsWithField) + private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap) throws IOException { final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) - DocIdSetIterator iterator = docsWithField.iterator(); + DocIdSetIterator iterator = fieldData.docsWithField.iterator(); for (int docID = iterator.nextDoc(); docID != DocIdSetIterator.NO_MORE_DOCS; docID = iterator.nextDoc()) { @@ -175,13 +324,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { } DocsWithFieldSet newDocsWithField = new DocsWithFieldSet(); final int[] ordMap = new int[offset - 1]; // new ord to old ord - final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord int ord = 0; int doc = 0; for (int docIdOffset : docIdOffsets) { if (docIdOffset != 0) { ordMap[ord] = docIdOffset - 1; - oldOrdMap[docIdOffset - 1] = ord; newDocsWithField.add(doc); ord++; } @@ -192,16 +339,22 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); writeSortedQuantizedVectors(fieldData, ordMap); long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset; - - return new long[] {vectorDataOffset, quantizedVectorLength}; + writeMeta( + fieldData.fieldInfo, + maxDoc, + vectorDataOffset, + quantizedVectorLength, + quantile, + fieldData.minQuantile, + fieldData.maxQuantile, + newDocsWithField); } - void writeSortedQuantizedVectors(QuantizationFieldVectorWriter fieldData, int[] ordMap) - throws IOException { + private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); - byte[] vector = new byte[fieldData.dim]; + byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; + float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null; for (int ordinal : ordMap) { float[] v = fieldData.floatVectors.get(ordinal); if (fieldData.normalize) { @@ -209,9 +362,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { VectorUtil.l2normalize(copy); v = copy; } - float offsetCorrection = - scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); + scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction()); quantizedVectorData.writeBytes(vector, vector.length); offsetBuffer.putFloat(offsetCorrection); quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length); @@ -219,10 +371,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { } } - ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { - return null; - } + private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) + throws IOException { + assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32; float quantile = this.quantile == null ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) @@ -230,15 +381,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile); } - ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField( + private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex( SegmentWriteState segmentWriteState, FieldInfo fieldInfo, MergeState mergeState, ScalarQuantizer mergedQuantizationState) throws IOException { - if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) { - return null; - } + long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES); IndexOutput tempQuantizedVectorData = segmentWriteState.directory.createTempOutput( quantizedVectorData.getName(), "temp", segmentWriteState.context); @@ -257,7 +406,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { tempQuantizedVectorData.getName(), segmentWriteState.context); quantizedVectorData.copyBytes( quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength()); + long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset; CodecUtil.retrieveChecksum(quantizationDataInput); + float quantile = + this.quantile == null + ? calculateDefaultQuantile(fieldInfo.getVectorDimension()) + : this.quantile; + writeMeta( + fieldInfo, + segmentWriteState.segmentInfo.maxDoc(), + vectorDataOffset, + vectorDataLength, + quantile, + mergedQuantizationState.getLowerQuantile(), + mergedQuantizationState.getUpperQuantile(), + docsWithField); success = true; final IndexInput finalQuantizationDataInput = quantizationDataInput; return new ScalarQuantizedCloseableRandomVectorScorerSupplier( @@ -265,6 +428,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { IOUtils.close(finalQuantizationDataInput); segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName()); }, + docsWithField.cardinality(), new ScalarQuantizedRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), mergedQuantizationState, @@ -427,43 +591,35 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { } @Override - public long ramBytesUsed() { - return BASE_RAM_BYTES_USED; + public void close() throws IOException { + IOUtils.close(meta, quantizedVectorData, rawVectorDelegate); } - static class QuantizationFieldVectorWriter implements Accountable { - private static final long SHALLOW_SIZE = - shallowSizeOfInstance(QuantizationFieldVectorWriter.class); - private final int dim; + static class FieldWriter extends FlatFieldVectorsWriter { + private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class); private final List floatVectors; - private final boolean normalize; - private final VectorSimilarityFunction vectorSimilarityFunction; + private final FieldInfo fieldInfo; private final float quantile; private final InfoStream infoStream; + private final boolean normalize; private float minQuantile = Float.POSITIVE_INFINITY; private float maxQuantile = Float.NEGATIVE_INFINITY; private boolean finished; + private final DocsWithFieldSet docsWithField; - static QuantizationFieldVectorWriter create( - FieldInfo fieldInfo, float quantile, InfoStream infoStream) { - return new QuantizationFieldVectorWriter( - fieldInfo.getVectorDimension(), - quantile, - fieldInfo.getVectorSimilarityFunction(), - infoStream); - } - - QuantizationFieldVectorWriter( - int dim, + @SuppressWarnings("unchecked") + FieldWriter( float quantile, - VectorSimilarityFunction vectorSimilarityFunction, - InfoStream infoStream) { - this.dim = dim; + FieldInfo fieldInfo, + InfoStream infoStream, + KnnFieldVectorsWriter indexWriter) { + super((KnnFieldVectorsWriter) indexWriter); this.quantile = quantile; - this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE; - this.vectorSimilarityFunction = vectorSimilarityFunction; + this.fieldInfo = fieldInfo; + this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE; this.floatVectors = new ArrayList<>(); this.infoStream = infoStream; + this.docsWithField = new DocsWithFieldSet(); } void finish() throws IOException { @@ -475,15 +631,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { return; } ScalarQuantizer quantizer = - ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile); + ScalarQuantizer.fromVectors( + new FloatVectorWrapper( + floatVectors, + fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE), + quantile); minQuantile = quantizer.getLowerQuantile(); maxQuantile = quantizer.getUpperQuantile(); if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) { infoStream.message( QUANTIZED_VECTOR_COMPONENT, "quantized field=" - + " dimension=" - + dim + " quantile=" + quantile + " minQuantile=" @@ -494,24 +652,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { finished = true; } - public void addValue(float[] vectorValue) throws IOException { - floatVectors.add(vectorValue); - } - - float getMinQuantile() { - assert finished; - return minQuantile; - } - - float getMaxQuantile() { - assert finished; - return maxQuantile; - } - - float getQuantile() { - return quantile; - } - ScalarQuantizer createQuantizer() { assert finished; return new ScalarQuantizer(minQuantile, maxQuantile, quantile); @@ -519,8 +659,26 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { @Override public long ramBytesUsed() { - if (floatVectors.size() == 0) return SHALLOW_SIZE; - return SHALLOW_SIZE + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + long size = SHALLOW_SIZE; + if (indexingDelegate != null) { + size += indexingDelegate.ramBytesUsed(); + } + if (floatVectors.size() == 0) return size; + return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF; + } + + @Override + public void addValue(int docID, float[] vectorValue) throws IOException { + docsWithField.add(docID); + floatVectors.add(vectorValue); + if (indexingDelegate != null) { + indexingDelegate.addValue(docID, vectorValue); + } + } + + @Override + public float[] copyValue(float[] vectorValue) { + throw new UnsupportedOperationException(); } } @@ -613,6 +771,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { // Either our quantization parameters are way different than the merged ones // Or we have never been quantized. if (reader == null + || reader.getQuantizationState(fieldInfo.name) == null || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) { sub = new QuantizedByteVectorValueSub( @@ -702,6 +861,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { private final FloatVectorValues values; private final ScalarQuantizer quantizer; private final byte[] quantizedVector; + private final float[] normalizedVector; private float offsetValue = 0f; private final VectorSimilarityFunction vectorSimilarityFunction; @@ -714,6 +874,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { this.quantizer = quantizer; this.quantizedVector = new byte[values.dimension()]; this.vectorSimilarityFunction = vectorSimilarityFunction; + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + this.normalizedVector = new float[values.dimension()]; + } else { + this.normalizedVector = null; + } } @Override @@ -745,8 +910,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { public int nextDoc() throws IOException { int doc = values.nextDoc(); if (doc != NO_MORE_DOCS) { - offsetValue = - quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); + quantize(); } return doc; } @@ -755,10 +919,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { public int advance(int target) throws IOException { int doc = values.advance(target); if (doc != NO_MORE_DOCS) { + quantize(); + } + return doc; + } + + private void quantize() throws IOException { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + offsetValue = + quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction); + } else { offsetValue = quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); } - return doc; } } @@ -767,11 +942,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { private final ScalarQuantizedRandomVectorScorerSupplier supplier; private final Closeable onClose; + private final int numVectors; ScalarQuantizedCloseableRandomVectorScorerSupplier( - Closeable onClose, ScalarQuantizedRandomVectorScorerSupplier supplier) { + Closeable onClose, int numVectors, ScalarQuantizedRandomVectorScorerSupplier supplier) { this.onClose = onClose; this.supplier = supplier; + this.numVectors = numVectors; } @Override @@ -788,6 +965,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { public void close() throws IOException { onClose.close(); } + + @Override + public int totalVectorCount() { + return numVectors; + } } private static final class OffsetCorrectedQuantizedByteVectorValues diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 60b5f8101c7..cf297f9b15c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -98,8 +98,6 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue } } - abstract Bits getAcceptOrds(Bits acceptDocs); - static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { private int doc = -1; @@ -138,7 +136,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } } @@ -196,7 +194,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { return null; } @@ -268,7 +266,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue } @Override - Bits getAcceptOrds(Bits acceptDocs) { + public Bits getAcceptOrds(Bits acceptDocs) { return null; } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorer.java index 424dfbeefd2..f6fc9a1c805 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/ScalarQuantizedRandomVectorScorer.java @@ -25,7 +25,8 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** Quantized vector scorer */ -final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer { +final class ScalarQuantizedRandomVectorScorer + extends RandomVectorScorer.AbstractRandomVectorScorer { private static float quantizeQuery( float[] query, @@ -54,6 +55,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer { RandomAccessQuantizedByteVectorValues values, byte[] query, float queryOffset) { + super(values); this.quantizedQuery = query; this.queryOffset = queryOffset; this.similarity = similarityFunction; @@ -65,6 +67,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer { ScalarQuantizer scalarQuantizer, RandomAccessQuantizedByteVectorValues values, float[] query) { + super(values); byte[] quantizedQuery = new byte[query.length]; float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer); this.quantizedQuery = quantizedQuery; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java index 1490624ced2..148963e7dc9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/CloseableRandomVectorScorerSupplier.java @@ -26,5 +26,6 @@ import java.io.Closeable; *

NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily * closeable */ -public interface CloseableRandomVectorScorerSupplier - extends Closeable, RandomVectorScorerSupplier {} +public interface CloseableRandomVectorScorerSupplier extends Closeable, RandomVectorScorerSupplier { + int totalVectorCount(); +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index d924a8f9cbd..5fc531d9f64 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -18,6 +18,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; +import org.apache.lucene.util.Bits; /** * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based @@ -56,4 +57,14 @@ public interface RandomAccessVectorValues { default int ordToDoc(int ord) { return ord; } + + /** + * Returns the {@link Bits} representing live documents. By default, this is an identity function. + * + * @param acceptDocs the accept docs + * @return the accept docs + */ + default Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index 36b7e331dc9..76c7bb910f7 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.util.Bits; /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ public interface RandomVectorScorer { @@ -30,6 +31,31 @@ public interface RandomVectorScorer { */ float score(int node) throws IOException; + /** + * @return the maximum possible ordinal for this scorer + */ + int maxOrd(); + + /** + * Translates vector ordinal to the correct document ID. By default, this is an identity function. + * + * @param ord the vector ordinal + * @return the document Id for that vector ordinal + */ + default int ordToDoc(int ord) { + return ord; + } + + /** + * Returns the {@link Bits} representing live documents. By default, this is an identity function. + * + * @param acceptDocs the accept docs + * @return the accept docs + */ + default Bits getAcceptOrds(Bits acceptDocs) { + return acceptDocs; + } + /** * Creates a default scorer for float vectors. * @@ -53,7 +79,12 @@ public interface RandomVectorScorer { + " differs from field dimension: " + vectors.dimension()); } - return node -> similarityFunction.compare(query, vectors.vectorValue(node)); + return new AbstractRandomVectorScorer<>(vectors) { + @Override + public float score(int node) throws IOException { + return similarityFunction.compare(query, vectors.vectorValue(node)); + } + }; } /** @@ -79,6 +110,44 @@ public interface RandomVectorScorer { + " differs from field dimension: " + vectors.dimension()); } - return node -> similarityFunction.compare(query, vectors.vectorValue(node)); + return new AbstractRandomVectorScorer<>(vectors) { + @Override + public float score(int node) throws IOException { + return similarityFunction.compare(query, vectors.vectorValue(node)); + } + }; + } + + /** + * Creates a default scorer for random access vectors. + * + * @param the type of the vector values + */ + abstract class AbstractRandomVectorScorer implements RandomVectorScorer { + private final RandomAccessVectorValues values; + + /** + * Creates a new scorer for the given vector values. + * + * @param values the vector values + */ + public AbstractRandomVectorScorer(RandomAccessVectorValues values) { + this.values = values; + } + + @Override + public int maxOrd() { + return values.size(); + } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public Bits getAcceptOrds(Bits acceptDocs) { + return values.getAcceptOrds(acceptDocs); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java index 1db50ee4562..2b809967552 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorerSupplier.java @@ -87,8 +87,12 @@ public interface RandomVectorScorerSupplier { @Override public RandomVectorScorer scorer(int ord) throws IOException { - return cand -> - similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); + return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) { + @Override + public float score(int cand) throws IOException { + return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); + } + }; } @Override @@ -115,8 +119,12 @@ public interface RandomVectorScorerSupplier { @Override public RandomVectorScorer scorer(int ord) throws IOException { - return cand -> - similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); + return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) { + @Override + public float score(int cand) throws IOException { + return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand)); + } + }; } @Override diff --git a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index 311137338fa..ba0de7b2464 100644 --- a/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/lucene/core/src/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -14,3 +14,4 @@ # limitations under the License. org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat +org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index d3d13d5ef61..f8e1fc128db 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -47,10 +47,9 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat return new Lucene99Codec() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene99HnswVectorsFormat( + return new Lucene99HnswScalarQuantizedVectorsFormat( Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN, - Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH, - new Lucene99ScalarQuantizedVectorsFormat()); + Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH); } }; } @@ -145,12 +144,11 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new Lucene99HnswVectorsFormat( - 10, 20, new Lucene99ScalarQuantizedVectorsFormat(0.9f)); + return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null); } }; String expectedString = - "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9))"; + "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))"; assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); } } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java index 8f2fdd2c75f..085a203dad9 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswVectorsFormat.java @@ -37,7 +37,7 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase } }; String expectedString = - "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=none)"; + "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat())"; assertEquals(expectedString, customCodec.knnVectorsFormat().toString()); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java deleted file mode 100644 index 871eb63a9b2..00000000000 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.codecs.lucene99; - -import java.util.ArrayList; -import java.util.List; -import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.ScalarQuantizer; - -public class TestLucene99ScalarQuantizedVectorsFormat extends LuceneTestCase { - - public void testDefaultQuantile() { - float defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(99); - assertEquals(0.99f, defaultQuantile, 1e-5); - defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(1); - assertEquals(0.9f, defaultQuantile, 1e-5); - defaultQuantile = - Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(Integer.MAX_VALUE - 2); - assertEquals(1.0f, defaultQuantile, 1e-5); - } - - public void testLimits() { - expectThrows( - IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(0.89f)); - expectThrows( - IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(1.1f)); - } - - public void testQuantileMergeWithMissing() { - List quantiles = new ArrayList<>(); - quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f)); - quantiles.add(null); - List segmentSizes = List.of(1, 1, 1, 1); - assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f)); - assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(List.of(), List.of(), 0.1f)); - } - - public void testQuantileMerge() { - List quantiles = new ArrayList<>(); - quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f)); - List segmentSizes = List.of(1, 1, 1); - ScalarQuantizer merged = - Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f); - assertEquals(0.2f, merged.getLowerQuantile(), 1e-5); - assertEquals(0.3f, merged.getUpperQuantile(), 1e-5); - assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5); - } - - public void testQuantileMergeWithDifferentSegmentSizes() { - List quantiles = new ArrayList<>(); - quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f)); - quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f)); - List segmentSizes = List.of(1, 2, 3); - ScalarQuantizer merged = - Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f); - assertEquals(0.2333333f, merged.getLowerQuantile(), 1e-5); - assertEquals(0.3333333f, merged.getUpperQuantile(), 1e-5); - assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5); - } -} diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 32f325970ab..6ea8867a36d 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -32,9 +32,9 @@ import java.util.concurrent.CountDownLatch; import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader; -import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -83,12 +83,7 @@ public class TestKnnGraph extends LuceneTestCase { int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1; similarityFunction = VectorSimilarityFunction.values()[similarity]; vectorEncoding = randomVectorEncoding(); - - Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat = - vectorEncoding.equals(VectorEncoding.FLOAT32) && randomBoolean() - ? new Lucene99ScalarQuantizedVectorsFormat(1f) - : null; - + boolean quantized = randomBoolean(); codec = new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) { @Override @@ -96,14 +91,16 @@ public class TestKnnGraph extends LuceneTestCase { return new PerFieldKnnVectorsFormat() { @Override public KnnVectorsFormat getKnnVectorsFormatForField(String field) { - return new Lucene99HnswVectorsFormat( - M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, scalarQuantizedVectorsFormat); + return quantized + ? new Lucene99HnswScalarQuantizedVectorsFormat( + M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH) + : new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); } }; } }; - if (vectorEncoding == VectorEncoding.FLOAT32 && scalarQuantizedVectorsFormat == null) { + if (vectorEncoding == VectorEncoding.FLOAT32) { float32Codec = codec; } else { float32Codec = diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java index 257c72fb994..17b709d2e66 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestNeighborArray.java @@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.Bits; public class TestNeighborArray extends LuceneTestCase { @@ -190,7 +191,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addOutOfOrder(7, Float.NaN); neighbors.addOutOfOrder(6, Float.NaN); neighbors.addOutOfOrder(4, Float.NaN); - int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 1); + int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 1); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); @@ -206,7 +207,7 @@ public class TestNeighborArray extends LuceneTestCase { neighbors.addOutOfOrder(17, Float.NaN); neighbors.addOutOfOrder(16, Float.NaN); neighbors.addOutOfOrder(14, Float.NaN); - int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 11); + int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 11); assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked); assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors); assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors); @@ -223,4 +224,21 @@ public class TestNeighborArray extends LuceneTestCase { assertEquals(nodes[i], neighbors.node[i]); } } + + interface TestRandomVectorScorer extends RandomVectorScorer { + @Override + default int maxOrd() { + throw new UnsupportedOperationException(); + } + + @Override + default int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + + @Override + default Bits getAcceptOrds(Bits acceptDocs) { + throw new UnsupportedOperationException(); + } + } }