mirror of https://github.com/apache/lucene.git
Refactoring HNSW to use a new internal FlatVectorFormat (#12729)
Currently the HNSW codec does too many things, it not only indexes vectors, but stores them and determines how to store them given the vector type. This PR extracts out the vector storage into a new format `Lucene99FlatVectorsFormat` and adds new base class called `FlatVectorsFormat`. This allows for some additional helper functions that allow an indexing codec (like HNSW) take advantage of the flat formats. Additionally, this PR refactors the new `Lucene99ScalarQuantizedVectorsFormat` to be a `FlatVectorsFormat`. Now, `Lucene99HnswVectorsFormat` is constructed with a `Lucene99FlatVectorsFormat` and a new `Lucene99HnswScalarQuantizedVectorsFormat` that uses `Lucene99ScalarQuantizedVectorsFormat`
This commit is contained in:
parent
c28d174cd7
commit
a47ba3369f
|
@ -184,6 +184,10 @@ New Features
|
||||||
* GITHUB#12660: HNSW graph now can be merged with multiple thread. Configurable in Lucene99HnswVectorsFormat.
|
* GITHUB#12660: HNSW graph now can be merged with multiple thread. Configurable in Lucene99HnswVectorsFormat.
|
||||||
(Patrick Zhai)
|
(Patrick Zhai)
|
||||||
|
|
||||||
|
* GITHUB#12729: Add new Lucene99FlatVectorsFormat for writing vectors in a flat format and refactor
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat & Lucene99HnswVectorsFormat to reuse the flat formats.
|
||||||
|
Additionally, this allows flat formats to be pluggable independent of HNSW. (Ben Trent)
|
||||||
|
|
||||||
Improvements
|
Improvements
|
||||||
---------------------
|
---------------------
|
||||||
* GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions
|
* GITHUB#12523: TaskExecutor waits for all tasks to complete before returning when Exceptions
|
||||||
|
|
|
@ -80,8 +80,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
@ -120,7 +118,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return acceptDocs;
|
return acceptDocs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -184,7 +182,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
if (acceptDocs == null) {
|
if (acceptDocs == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -256,7 +254,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,8 +89,6 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
@ -129,7 +127,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return acceptDocs;
|
return acceptDocs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,7 +194,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
if (acceptDocs == null) {
|
if (acceptDocs == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -268,7 +266,7 @@ abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,8 +86,6 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
@ -126,7 +124,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return acceptDocs;
|
return acceptDocs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -193,7 +191,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
if (acceptDocs == null) {
|
if (acceptDocs == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -265,7 +263,7 @@ abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@
|
||||||
*/
|
*/
|
||||||
|
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
|
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
|
||||||
|
|
||||||
/** Lucene Core. */
|
/** Lucene Core. */
|
||||||
@SuppressWarnings("module") // the test framework is compiled after the core...
|
@SuppressWarnings("module") // the test framework is compiled after the core...
|
||||||
|
@ -70,7 +69,8 @@ module org.apache.lucene.core {
|
||||||
provides org.apache.lucene.codecs.DocValuesFormat with
|
provides org.apache.lucene.codecs.DocValuesFormat with
|
||||||
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||||
provides org.apache.lucene.codecs.KnnVectorsFormat with
|
provides org.apache.lucene.codecs.KnnVectorsFormat with
|
||||||
Lucene99HnswVectorsFormat;
|
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat,
|
||||||
|
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
|
||||||
provides org.apache.lucene.codecs.PostingsFormat with
|
provides org.apache.lucene.codecs.PostingsFormat with
|
||||||
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
|
org.apache.lucene.codecs.lucene99.Lucene99PostingsFormat;
|
||||||
provides org.apache.lucene.index.SortFieldProvider with
|
provides org.apache.lucene.index.SortFieldProvider with
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Vectors' writer for a field
|
||||||
|
*
|
||||||
|
* @param <T> an array type; the type of vectors to be written
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract class FlatFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The delegate to write to, can be null When non-null, all vectors seen should be written to the
|
||||||
|
* delegate along with being written to the flat vectors.
|
||||||
|
*/
|
||||||
|
protected final KnnFieldVectorsWriter<T> indexingDelegate;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sole constructor that expects some indexingDelegate. All vectors seen should be written to the
|
||||||
|
* delegate along with being written to the flat vectors.
|
||||||
|
*
|
||||||
|
* @param indexingDelegate the delegate to write to, can be null
|
||||||
|
*/
|
||||||
|
protected FlatFieldVectorsWriter(KnnFieldVectorsWriter<T> indexingDelegate) {
|
||||||
|
this.indexingDelegate = indexingDelegate;
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,39 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Encodes/decodes per-document vectors
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract class FlatVectorsFormat {
|
||||||
|
|
||||||
|
/** Sole constructor */
|
||||||
|
protected FlatVectorsFormat() {}
|
||||||
|
|
||||||
|
/** Returns a {@link FlatVectorsWriter} to write the vectors to the index. */
|
||||||
|
public abstract FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException;
|
||||||
|
|
||||||
|
/** Returns a {@link KnnVectorsReader} to read the vectors from the index. */
|
||||||
|
public abstract FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException;
|
||||||
|
}
|
|
@ -0,0 +1,92 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.util.Accountable;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads vectors from an index. When searching this reader, it iterates every vector in the index
|
||||||
|
* and scores them
|
||||||
|
*
|
||||||
|
* <p>This class is useful when:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>the number of vectors is small
|
||||||
|
* <li>when used along side some additional indexing structure that can be used to better search
|
||||||
|
* the vectors (like HNSW).
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract class FlatVectorsReader implements Closeable, Accountable {
|
||||||
|
|
||||||
|
/** Sole constructor */
|
||||||
|
protected FlatVectorsReader() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a {@link RandomVectorScorer} for the given field and target vector.
|
||||||
|
*
|
||||||
|
* @param field the field to search
|
||||||
|
* @param target the target vector
|
||||||
|
* @return a {@link RandomVectorScorer} for the given field and target vector.
|
||||||
|
* @throws IOException if an I/O error occurs when reading from the index.
|
||||||
|
*/
|
||||||
|
public abstract RandomVectorScorer getRandomVectorScorer(String field, float[] target)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns a {@link RandomVectorScorer} for the given field and target vector.
|
||||||
|
*
|
||||||
|
* @param field the field to search
|
||||||
|
* @param target the target vector
|
||||||
|
* @return a {@link RandomVectorScorer} for the given field and target vector.
|
||||||
|
* @throws IOException if an I/O error occurs when reading from the index.
|
||||||
|
*/
|
||||||
|
public abstract RandomVectorScorer getRandomVectorScorer(String field, byte[] target)
|
||||||
|
throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks consistency of this reader.
|
||||||
|
*
|
||||||
|
* <p>Note that this may be costly in terms of I/O, e.g. may involve computing a checksum value
|
||||||
|
* against large data files.
|
||||||
|
*
|
||||||
|
* @lucene.internal
|
||||||
|
*/
|
||||||
|
public abstract void checkIntegrity() throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link FloatVectorValues} for the given {@code field}. The behavior is undefined if
|
||||||
|
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||||
|
* never {@code null}.
|
||||||
|
*/
|
||||||
|
public abstract FloatVectorValues getFloatVectorValues(String field) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link ByteVectorValues} for the given {@code field}. The behavior is undefined if
|
||||||
|
* the given field doesn't have KNN vectors enabled on its {@link FieldInfo}. The return value is
|
||||||
|
* never {@code null}.
|
||||||
|
*/
|
||||||
|
public abstract ByteVectorValues getByteVectorValues(String field) throws IOException;
|
||||||
|
}
|
|
@ -0,0 +1,74 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.MergeState;
|
||||||
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.util.Accountable;
|
||||||
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Vectors' writer for a field that allows additional indexing logic to be implemented by the caller
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public abstract class FlatVectorsWriter implements Accountable, Closeable {
|
||||||
|
|
||||||
|
/** Sole constructor */
|
||||||
|
protected FlatVectorsWriter() {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Add a new field for indexing, allowing the user to provide a writer that the flat vectors
|
||||||
|
* writer can delegate to if additional indexing logic is required.
|
||||||
|
*
|
||||||
|
* @param fieldInfo fieldInfo of the field to add
|
||||||
|
* @param indexWriter the writer to delegate to, can be null
|
||||||
|
* @return a writer for the field
|
||||||
|
* @throws IOException if an I/O error occurs when adding the field
|
||||||
|
*/
|
||||||
|
public abstract FlatFieldVectorsWriter<?> addField(
|
||||||
|
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Write the field for merging, providing a scorer over the newly merged flat vectors. This way
|
||||||
|
* any additional merging logic can be implemented by the user of this class.
|
||||||
|
*
|
||||||
|
* @param fieldInfo fieldInfo of the field to merge
|
||||||
|
* @param mergeState mergeState of the segments to merge
|
||||||
|
* @return a scorer over the newly merged flat vectors, which should be closed as it holds a
|
||||||
|
* temporary file handle to read over the newly merged vectors
|
||||||
|
* @throws IOException if an I/O error occurs when merging
|
||||||
|
*/
|
||||||
|
public abstract CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
|
||||||
|
FieldInfo fieldInfo, MergeState mergeState) throws IOException;
|
||||||
|
|
||||||
|
/** Write field for merging */
|
||||||
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
IOUtils.close(mergeOneFieldToIndex(fieldInfo, mergeState));
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Called once at the end before close */
|
||||||
|
public abstract void finish() throws IOException;
|
||||||
|
|
||||||
|
/** Flush all buffered data on disk * */
|
||||||
|
public abstract void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException;
|
||||||
|
}
|
|
@ -94,8 +94,6 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
||||||
* vector.
|
* vector.
|
||||||
|
|
|
@ -88,8 +88,6 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
||||||
* vector.
|
* vector.
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lucene 9.9 flat vector format, which encodes numeric vector values
|
||||||
|
*
|
||||||
|
* <h2>.vec (vector data) file</h2>
|
||||||
|
*
|
||||||
|
* <p>For each field:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
|
||||||
|
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
|
||||||
|
* sample is stored as an IEEE float in little-endian byte order.
|
||||||
|
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
||||||
|
* note that only in sparse case
|
||||||
|
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||||
|
* that only in sparse case
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* <h2>.vemf (vector metadata) file</h2>
|
||||||
|
*
|
||||||
|
* <p>For each field:
|
||||||
|
*
|
||||||
|
* <ul>
|
||||||
|
* <li><b>[int32]</b> field number
|
||||||
|
* <li><b>[int32]</b> vector similarity function ordinal
|
||||||
|
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
|
||||||
|
* <li><b>[vlong]</b> length of this field's vectors, in bytes
|
||||||
|
* <li><b>[vint]</b> dimension of this field's vectors
|
||||||
|
* <li><b>[int]</b> the number of documents having values for this field
|
||||||
|
* <li><b>[int8]</b> if equals to -1, dense – all documents have values for a field. If equals to
|
||||||
|
* 0, sparse – some documents missing values.
|
||||||
|
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
|
||||||
|
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||||
|
* that only in sparse case
|
||||||
|
* </ul>
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public final class Lucene99FlatVectorsFormat extends FlatVectorsFormat {
|
||||||
|
|
||||||
|
static final String META_CODEC_NAME = "Lucene99FlatVectorsFormatMeta";
|
||||||
|
static final String VECTOR_DATA_CODEC_NAME = "Lucene99FlatVectorsFormatData";
|
||||||
|
static final String META_EXTENSION = "vemf";
|
||||||
|
static final String VECTOR_DATA_EXTENSION = "vec";
|
||||||
|
|
||||||
|
public static final int VERSION_START = 0;
|
||||||
|
public static final int VERSION_CURRENT = VERSION_START;
|
||||||
|
|
||||||
|
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
|
||||||
|
|
||||||
|
/** Constructs a format */
|
||||||
|
public Lucene99FlatVectorsFormat() {
|
||||||
|
super();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
return new Lucene99FlatVectorsWriter(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
|
return new Lucene99FlatVectorsReader(state);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Lucene99FlatVectorsFormat()";
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,333 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
|
import org.apache.lucene.store.DataInput;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.util.Accountable;
|
||||||
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Reads vectors from the index segments.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public final class Lucene99FlatVectorsReader extends FlatVectorsReader {
|
||||||
|
|
||||||
|
private static final long SHALLOW_SIZE =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsFormat.class);
|
||||||
|
|
||||||
|
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||||
|
private final IndexInput vectorData;
|
||||||
|
|
||||||
|
Lucene99FlatVectorsReader(SegmentReadState state) throws IOException {
|
||||||
|
int versionMeta = readMetadata(state);
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
vectorData =
|
||||||
|
openDataInput(
|
||||||
|
state,
|
||||||
|
versionMeta,
|
||||||
|
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION,
|
||||||
|
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private int readMetadata(SegmentReadState state) throws IOException {
|
||||||
|
String metaFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
|
||||||
|
int versionMeta = -1;
|
||||||
|
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
|
||||||
|
Throwable priorE = null;
|
||||||
|
try {
|
||||||
|
versionMeta =
|
||||||
|
CodecUtil.checkIndexHeader(
|
||||||
|
meta,
|
||||||
|
Lucene99FlatVectorsFormat.META_CODEC_NAME,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_START,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
readFields(meta, state.fieldInfos);
|
||||||
|
} catch (Throwable exception) {
|
||||||
|
priorE = exception;
|
||||||
|
} finally {
|
||||||
|
CodecUtil.checkFooter(meta, priorE);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return versionMeta;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IndexInput openDataInput(
|
||||||
|
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
|
||||||
|
throws IOException {
|
||||||
|
String fileName =
|
||||||
|
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
|
||||||
|
IndexInput in = state.directory.openInput(fileName, state.context);
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
int versionVectorData =
|
||||||
|
CodecUtil.checkIndexHeader(
|
||||||
|
in,
|
||||||
|
codecName,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_START,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
if (versionMeta != versionVectorData) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"Format versions mismatch: meta="
|
||||||
|
+ versionMeta
|
||||||
|
+ ", "
|
||||||
|
+ codecName
|
||||||
|
+ "="
|
||||||
|
+ versionVectorData,
|
||||||
|
in);
|
||||||
|
}
|
||||||
|
CodecUtil.retrieveChecksum(in);
|
||||||
|
success = true;
|
||||||
|
return in;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
|
||||||
|
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
|
||||||
|
FieldInfo info = infos.fieldInfo(fieldNumber);
|
||||||
|
if (info == null) {
|
||||||
|
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||||
|
}
|
||||||
|
FieldEntry fieldEntry = readField(meta);
|
||||||
|
validateFieldEntry(info, fieldEntry);
|
||||||
|
fields.put(info.name, fieldEntry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
|
||||||
|
int dimension = info.getVectorDimension();
|
||||||
|
if (dimension != fieldEntry.dimension) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Inconsistent vector dimension for field=\""
|
||||||
|
+ info.name
|
||||||
|
+ "\"; "
|
||||||
|
+ dimension
|
||||||
|
+ " != "
|
||||||
|
+ fieldEntry.dimension);
|
||||||
|
}
|
||||||
|
|
||||||
|
int byteSize =
|
||||||
|
switch (info.getVectorEncoding()) {
|
||||||
|
case BYTE -> Byte.BYTES;
|
||||||
|
case FLOAT32 -> Float.BYTES;
|
||||||
|
};
|
||||||
|
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
|
||||||
|
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
|
||||||
|
if (numBytes != fieldEntry.vectorDataLength) {
|
||||||
|
throw new IllegalStateException(
|
||||||
|
"Vector data length "
|
||||||
|
+ fieldEntry.vectorDataLength
|
||||||
|
+ " not matching size="
|
||||||
|
+ fieldEntry.size
|
||||||
|
+ " * dim="
|
||||||
|
+ dimension
|
||||||
|
+ " * byteSize="
|
||||||
|
+ byteSize
|
||||||
|
+ " = "
|
||||||
|
+ numBytes);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||||
|
int similarityFunctionId = input.readInt();
|
||||||
|
if (similarityFunctionId < 0
|
||||||
|
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"Invalid similarity function id: " + similarityFunctionId, input);
|
||||||
|
}
|
||||||
|
return VectorSimilarityFunction.values()[similarityFunctionId];
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||||
|
int encodingId = input.readInt();
|
||||||
|
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||||
|
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||||
|
}
|
||||||
|
return VectorEncoding.values()[encodingId];
|
||||||
|
}
|
||||||
|
|
||||||
|
private FieldEntry readField(IndexInput input) throws IOException {
|
||||||
|
VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||||
|
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||||
|
return new FieldEntry(input, vectorEncoding, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
return Lucene99FlatVectorsReader.SHALLOW_SIZE
|
||||||
|
+ RamUsageEstimator.sizeOfMap(
|
||||||
|
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void checkIntegrity() throws IOException {
|
||||||
|
CodecUtil.checksumEntireFile(vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
|
}
|
||||||
|
return OffHeapFloatVectorValues.load(
|
||||||
|
fieldEntry.ordToDoc,
|
||||||
|
fieldEntry.vectorEncoding,
|
||||||
|
fieldEntry.dimension,
|
||||||
|
fieldEntry.vectorDataOffset,
|
||||||
|
fieldEntry.vectorDataLength,
|
||||||
|
vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"field=\""
|
||||||
|
+ field
|
||||||
|
+ "\" is encoded as: "
|
||||||
|
+ fieldEntry.vectorEncoding
|
||||||
|
+ " expected: "
|
||||||
|
+ VectorEncoding.FLOAT32);
|
||||||
|
}
|
||||||
|
return OffHeapByteVectorValues.load(
|
||||||
|
fieldEntry.ordToDoc,
|
||||||
|
fieldEntry.vectorEncoding,
|
||||||
|
fieldEntry.dimension,
|
||||||
|
fieldEntry.vectorDataOffset,
|
||||||
|
fieldEntry.vectorDataLength,
|
||||||
|
vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return RandomVectorScorer.createFloats(
|
||||||
|
OffHeapFloatVectorValues.load(
|
||||||
|
fieldEntry.ordToDoc,
|
||||||
|
fieldEntry.vectorEncoding,
|
||||||
|
fieldEntry.dimension,
|
||||||
|
fieldEntry.vectorDataOffset,
|
||||||
|
fieldEntry.vectorDataLength,
|
||||||
|
vectorData),
|
||||||
|
fieldEntry.similarityFunction,
|
||||||
|
target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return RandomVectorScorer.createBytes(
|
||||||
|
OffHeapByteVectorValues.load(
|
||||||
|
fieldEntry.ordToDoc,
|
||||||
|
fieldEntry.vectorEncoding,
|
||||||
|
fieldEntry.dimension,
|
||||||
|
fieldEntry.vectorDataOffset,
|
||||||
|
fieldEntry.vectorDataLength,
|
||||||
|
vectorData),
|
||||||
|
fieldEntry.similarityFunction,
|
||||||
|
target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
IOUtils.close(vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class FieldEntry implements Accountable {
|
||||||
|
private static final long SHALLOW_SIZE =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
|
||||||
|
final VectorSimilarityFunction similarityFunction;
|
||||||
|
final VectorEncoding vectorEncoding;
|
||||||
|
final int dimension;
|
||||||
|
final long vectorDataOffset;
|
||||||
|
final long vectorDataLength;
|
||||||
|
final int size;
|
||||||
|
final OrdToDocDISIReaderConfiguration ordToDoc;
|
||||||
|
|
||||||
|
FieldEntry(
|
||||||
|
IndexInput input,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction)
|
||||||
|
throws IOException {
|
||||||
|
this.similarityFunction = similarityFunction;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
|
vectorDataOffset = input.readVLong();
|
||||||
|
vectorDataLength = input.readVLong();
|
||||||
|
dimension = input.readVInt();
|
||||||
|
size = input.readInt();
|
||||||
|
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,508 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
|
||||||
|
import java.io.Closeable;
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.nio.ByteBuffer;
|
||||||
|
import java.nio.ByteOrder;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.MergeState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.index.Sorter;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.store.IndexOutput;
|
||||||
|
import org.apache.lucene.util.ArrayUtil;
|
||||||
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes vector values to index segments.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public final class Lucene99FlatVectorsWriter extends FlatVectorsWriter {
|
||||||
|
|
||||||
|
private static final long SHALLLOW_RAM_BYTES_USED =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99FlatVectorsWriter.class);
|
||||||
|
|
||||||
|
private final SegmentWriteState segmentWriteState;
|
||||||
|
private final IndexOutput meta, vectorData;
|
||||||
|
|
||||||
|
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||||
|
private boolean finished;
|
||||||
|
|
||||||
|
Lucene99FlatVectorsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
segmentWriteState = state;
|
||||||
|
String metaFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name, state.segmentSuffix, Lucene99FlatVectorsFormat.META_EXTENSION);
|
||||||
|
|
||||||
|
String vectorDataFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
Lucene99FlatVectorsFormat.VECTOR_DATA_EXTENSION);
|
||||||
|
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
meta = state.directory.createOutput(metaFileName, state.context);
|
||||||
|
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
|
||||||
|
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
meta,
|
||||||
|
Lucene99FlatVectorsFormat.META_CODEC_NAME,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
vectorData,
|
||||||
|
Lucene99FlatVectorsFormat.VECTOR_DATA_CODEC_NAME,
|
||||||
|
Lucene99FlatVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatFieldVectorsWriter<?> addField(
|
||||||
|
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
|
||||||
|
FieldWriter<?> newField = FieldWriter.create(fieldInfo, indexWriter);
|
||||||
|
fields.add(newField);
|
||||||
|
return newField;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||||
|
for (FieldWriter<?> field : fields) {
|
||||||
|
if (sortMap == null) {
|
||||||
|
writeField(field, maxDoc);
|
||||||
|
} else {
|
||||||
|
writeSortingField(field, maxDoc, sortMap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void finish() throws IOException {
|
||||||
|
if (finished) {
|
||||||
|
throw new IllegalStateException("already finished");
|
||||||
|
}
|
||||||
|
finished = true;
|
||||||
|
if (meta != null) {
|
||||||
|
// write end of fields marker
|
||||||
|
meta.writeInt(-1);
|
||||||
|
CodecUtil.writeFooter(meta);
|
||||||
|
}
|
||||||
|
if (vectorData != null) {
|
||||||
|
CodecUtil.writeFooter(vectorData);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
long total = SHALLLOW_RAM_BYTES_USED;
|
||||||
|
for (FieldWriter<?> field : fields) {
|
||||||
|
total += field.ramBytesUsed();
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeField(FieldWriter<?> fieldData, int maxDoc) throws IOException {
|
||||||
|
// write vector values
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeByteVectors(fieldData);
|
||||||
|
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
||||||
|
}
|
||||||
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
|
writeMeta(
|
||||||
|
fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, fieldData.docsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
|
||||||
|
final ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
for (Object v : fieldData.vectors) {
|
||||||
|
buffer.asFloatBuffer().put((float[]) v);
|
||||||
|
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
|
||||||
|
for (Object v : fieldData.vectors) {
|
||||||
|
byte[] vector = (byte[]) v;
|
||||||
|
vectorData.writeBytes(vector, vector.length);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeSortingField(FieldWriter<?> fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||||
|
throws IOException {
|
||||||
|
final int[] docIdOffsets = new int[sortMap.size()];
|
||||||
|
int offset = 1; // 0 means no vector for this (field, document)
|
||||||
|
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||||
|
for (int docID = iterator.nextDoc();
|
||||||
|
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
docID = iterator.nextDoc()) {
|
||||||
|
int newDocID = sortMap.oldToNew(docID);
|
||||||
|
docIdOffsets[newDocID] = offset++;
|
||||||
|
}
|
||||||
|
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
||||||
|
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
||||||
|
int ord = 0;
|
||||||
|
int doc = 0;
|
||||||
|
for (int docIdOffset : docIdOffsets) {
|
||||||
|
if (docIdOffset != 0) {
|
||||||
|
ordMap[ord] = docIdOffset - 1;
|
||||||
|
newDocsWithField.add(doc);
|
||||||
|
ord++;
|
||||||
|
}
|
||||||
|
doc++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// write vector values
|
||||||
|
long vectorDataOffset =
|
||||||
|
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
|
||||||
|
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
|
||||||
|
};
|
||||||
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
|
writeMeta(fieldData.fieldInfo, maxDoc, vectorDataOffset, vectorDataLength, newDocsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
|
||||||
|
throws IOException {
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
final ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
for (int ordinal : ordMap) {
|
||||||
|
float[] vector = (float[]) fieldData.vectors.get(ordinal);
|
||||||
|
buffer.asFloatBuffer().put(vector);
|
||||||
|
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
||||||
|
}
|
||||||
|
return vectorDataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
for (int ordinal : ordMap) {
|
||||||
|
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
|
||||||
|
vectorData.writeBytes(vector, vector.length);
|
||||||
|
}
|
||||||
|
return vectorDataOffset;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
// Since we know we will not be searching for additional indexing, we can just write the
|
||||||
|
// the vectors directly to the new segment.
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
// No need to use temporary file as we don't have to re-open for reading
|
||||||
|
DocsWithFieldSet docsWithField =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeByteVectorData(
|
||||||
|
vectorData,
|
||||||
|
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||||
|
case FLOAT32 -> writeVectorData(
|
||||||
|
vectorData,
|
||||||
|
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||||
|
};
|
||||||
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
writeMeta(
|
||||||
|
fieldInfo,
|
||||||
|
segmentWriteState.segmentInfo.maxDoc(),
|
||||||
|
vectorDataOffset,
|
||||||
|
vectorDataLength,
|
||||||
|
docsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
|
||||||
|
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||||
|
IndexOutput tempVectorData =
|
||||||
|
segmentWriteState.directory.createTempOutput(
|
||||||
|
vectorData.getName(), "temp", segmentWriteState.context);
|
||||||
|
IndexInput vectorDataInput = null;
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
// write the vector data to a temporary file
|
||||||
|
DocsWithFieldSet docsWithField =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> writeByteVectorData(
|
||||||
|
tempVectorData,
|
||||||
|
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||||
|
case FLOAT32 -> writeVectorData(
|
||||||
|
tempVectorData,
|
||||||
|
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||||
|
};
|
||||||
|
CodecUtil.writeFooter(tempVectorData);
|
||||||
|
IOUtils.close(tempVectorData);
|
||||||
|
|
||||||
|
// copy the temporary file vectors to the actual data file
|
||||||
|
vectorDataInput =
|
||||||
|
segmentWriteState.directory.openInput(
|
||||||
|
tempVectorData.getName(), segmentWriteState.context);
|
||||||
|
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
||||||
|
CodecUtil.retrieveChecksum(vectorDataInput);
|
||||||
|
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
writeMeta(
|
||||||
|
fieldInfo,
|
||||||
|
segmentWriteState.segmentInfo.maxDoc(),
|
||||||
|
vectorDataOffset,
|
||||||
|
vectorDataLength,
|
||||||
|
docsWithField);
|
||||||
|
success = true;
|
||||||
|
final IndexInput finalVectorDataInput = vectorDataInput;
|
||||||
|
final RandomVectorScorerSupplier randomVectorScorerSupplier =
|
||||||
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
||||||
|
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
finalVectorDataInput,
|
||||||
|
fieldInfo.getVectorDimension() * Byte.BYTES),
|
||||||
|
fieldInfo.getVectorSimilarityFunction());
|
||||||
|
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
|
||||||
|
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||||
|
fieldInfo.getVectorDimension(),
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
finalVectorDataInput,
|
||||||
|
fieldInfo.getVectorDimension() * Float.BYTES),
|
||||||
|
fieldInfo.getVectorSimilarityFunction());
|
||||||
|
};
|
||||||
|
return new FlatCloseableRandomVectorScorerSupplier(
|
||||||
|
() -> {
|
||||||
|
IOUtils.close(finalVectorDataInput);
|
||||||
|
segmentWriteState.directory.deleteFile(tempVectorData.getName());
|
||||||
|
},
|
||||||
|
docsWithField.cardinality(),
|
||||||
|
randomVectorScorerSupplier);
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(vectorDataInput, tempVectorData);
|
||||||
|
IOUtils.deleteFilesIgnoringExceptions(
|
||||||
|
segmentWriteState.directory, tempVectorData.getName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeMeta(
|
||||||
|
FieldInfo field,
|
||||||
|
int maxDoc,
|
||||||
|
long vectorDataOffset,
|
||||||
|
long vectorDataLength,
|
||||||
|
DocsWithFieldSet docsWithField)
|
||||||
|
throws IOException {
|
||||||
|
meta.writeInt(field.number);
|
||||||
|
meta.writeInt(field.getVectorEncoding().ordinal());
|
||||||
|
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||||
|
meta.writeVLong(vectorDataOffset);
|
||||||
|
meta.writeVLong(vectorDataLength);
|
||||||
|
meta.writeVInt(field.getVectorDimension());
|
||||||
|
|
||||||
|
// write docIDs
|
||||||
|
int count = docsWithField.cardinality();
|
||||||
|
meta.writeInt(count);
|
||||||
|
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
||||||
|
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||||
|
* vectors.
|
||||||
|
*/
|
||||||
|
private static DocsWithFieldSet writeByteVectorData(
|
||||||
|
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||||
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
|
for (int docV = byteVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = byteVectorValues.nextDoc()) {
|
||||||
|
// write vector
|
||||||
|
byte[] binaryValue = byteVectorValues.vectorValue();
|
||||||
|
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||||
|
output.writeBytes(binaryValue, binaryValue.length);
|
||||||
|
docsWithField.add(docV);
|
||||||
|
}
|
||||||
|
return docsWithField;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||||
|
*/
|
||||||
|
private static DocsWithFieldSet writeVectorData(
|
||||||
|
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
|
||||||
|
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||||
|
ByteBuffer buffer =
|
||||||
|
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||||
|
.order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
for (int docV = floatVectorValues.nextDoc();
|
||||||
|
docV != NO_MORE_DOCS;
|
||||||
|
docV = floatVectorValues.nextDoc()) {
|
||||||
|
// write vector
|
||||||
|
float[] value = floatVectorValues.vectorValue();
|
||||||
|
buffer.asFloatBuffer().put(value);
|
||||||
|
output.writeBytes(buffer.array(), buffer.limit());
|
||||||
|
docsWithField.add(docV);
|
||||||
|
}
|
||||||
|
return docsWithField;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
IOUtils.close(meta, vectorData);
|
||||||
|
}
|
||||||
|
|
||||||
|
private abstract static class FieldWriter<T> extends FlatFieldVectorsWriter<T> {
|
||||||
|
private static final long SHALLOW_RAM_BYTES_USED =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
|
||||||
|
private final FieldInfo fieldInfo;
|
||||||
|
private final int dim;
|
||||||
|
private final DocsWithFieldSet docsWithField;
|
||||||
|
private final List<T> vectors;
|
||||||
|
|
||||||
|
private int lastDocID = -1;
|
||||||
|
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
static FieldWriter<?> create(FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) {
|
||||||
|
int dim = fieldInfo.getVectorDimension();
|
||||||
|
return switch (fieldInfo.getVectorEncoding()) {
|
||||||
|
case BYTE -> new Lucene99FlatVectorsWriter.FieldWriter<>(
|
||||||
|
fieldInfo, (KnnFieldVectorsWriter<byte[]>) indexWriter) {
|
||||||
|
@Override
|
||||||
|
public byte[] copyValue(byte[] value) {
|
||||||
|
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
case FLOAT32 -> new Lucene99FlatVectorsWriter.FieldWriter<>(
|
||||||
|
fieldInfo, (KnnFieldVectorsWriter<float[]>) indexWriter) {
|
||||||
|
@Override
|
||||||
|
public float[] copyValue(float[] value) {
|
||||||
|
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
FieldWriter(FieldInfo fieldInfo, KnnFieldVectorsWriter<T> indexWriter) {
|
||||||
|
super(indexWriter);
|
||||||
|
this.fieldInfo = fieldInfo;
|
||||||
|
this.dim = fieldInfo.getVectorDimension();
|
||||||
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
|
vectors = new ArrayList<>();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addValue(int docID, T vectorValue) throws IOException {
|
||||||
|
if (docID == lastDocID) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"VectorValuesField \""
|
||||||
|
+ fieldInfo.name
|
||||||
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
|
}
|
||||||
|
assert docID > lastDocID;
|
||||||
|
T copy = copyValue(vectorValue);
|
||||||
|
docsWithField.add(docID);
|
||||||
|
vectors.add(copy);
|
||||||
|
lastDocID = docID;
|
||||||
|
if (indexingDelegate != null) {
|
||||||
|
indexingDelegate.addValue(docID, copy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
long size = SHALLOW_RAM_BYTES_USED;
|
||||||
|
if (indexingDelegate != null) {
|
||||||
|
size += indexingDelegate.ramBytesUsed();
|
||||||
|
}
|
||||||
|
if (vectors.size() == 0) return size;
|
||||||
|
return size
|
||||||
|
+ docsWithField.ramBytesUsed()
|
||||||
|
+ (long) vectors.size()
|
||||||
|
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||||
|
+ (long) vectors.size()
|
||||||
|
* fieldInfo.getVectorDimension()
|
||||||
|
* fieldInfo.getVectorEncoding().byteSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static final class FlatCloseableRandomVectorScorerSupplier
|
||||||
|
implements CloseableRandomVectorScorerSupplier {
|
||||||
|
|
||||||
|
private final RandomVectorScorerSupplier supplier;
|
||||||
|
private final Closeable onClose;
|
||||||
|
private final int numVectors;
|
||||||
|
|
||||||
|
FlatCloseableRandomVectorScorerSupplier(
|
||||||
|
Closeable onClose, int numVectors, RandomVectorScorerSupplier supplier) {
|
||||||
|
this.onClose = onClose;
|
||||||
|
this.supplier = supplier;
|
||||||
|
this.numVectors = numVectors;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||||
|
return supplier.scorer(ord);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorerSupplier copy() throws IOException {
|
||||||
|
return supplier.copy();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
onClose.close();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int totalVectorCount() {
|
||||||
|
return numVectors;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,159 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH;
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN;
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DEFAULT_NUM_MERGE_WORKER;
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_BEAM_WIDTH;
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.MAXIMUM_MAX_CONN;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting
|
||||||
|
* the documents having values. The graph is used to power HNSW search. The format consists of two
|
||||||
|
* files, and uses {@link Lucene99ScalarQuantizedVectorsFormat} to store the actual vectors: For
|
||||||
|
* details on graph storage and file extensions, see {@link Lucene99HnswVectorsFormat}.
|
||||||
|
*
|
||||||
|
* @lucene.experimental
|
||||||
|
*/
|
||||||
|
public final class Lucene99HnswScalarQuantizedVectorsFormat extends KnnVectorsFormat {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
|
||||||
|
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
|
||||||
|
*/
|
||||||
|
private final int maxConn;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The number of candidate neighbors to track while searching the graph for each newly inserted
|
||||||
|
* node. Defaults to to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
|
||||||
|
* HnswGraph} for details.
|
||||||
|
*/
|
||||||
|
private final int beamWidth;
|
||||||
|
|
||||||
|
/** The format for storing, reading, merging vectors on disk */
|
||||||
|
private final FlatVectorsFormat flatVectorsFormat;
|
||||||
|
|
||||||
|
private final int numMergeWorkers;
|
||||||
|
private final ExecutorService mergeExec;
|
||||||
|
|
||||||
|
/** Constructs a format using default graph construction parameters */
|
||||||
|
public Lucene99HnswScalarQuantizedVectorsFormat() {
|
||||||
|
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a format using the given graph construction parameters.
|
||||||
|
*
|
||||||
|
* @param maxConn the maximum number of connections to a node in the HNSW graph
|
||||||
|
* @param beamWidth the size of the queue maintained during graph construction.
|
||||||
|
*/
|
||||||
|
public Lucene99HnswScalarQuantizedVectorsFormat(int maxConn, int beamWidth) {
|
||||||
|
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null, null);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Constructs a format using the given graph construction parameters and scalar quantization.
|
||||||
|
*
|
||||||
|
* @param maxConn the maximum number of connections to a node in the HNSW graph
|
||||||
|
* @param beamWidth the size of the queue maintained during graph construction.
|
||||||
|
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
|
||||||
|
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
|
||||||
|
* @param configuredQuantile the quantile for scalar quantizing the vectors, when `null` it is
|
||||||
|
* calculated based on the vector field dimensions.
|
||||||
|
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
|
||||||
|
* generated by this format to do the merge
|
||||||
|
*/
|
||||||
|
public Lucene99HnswScalarQuantizedVectorsFormat(
|
||||||
|
int maxConn,
|
||||||
|
int beamWidth,
|
||||||
|
int numMergeWorkers,
|
||||||
|
Float configuredQuantile,
|
||||||
|
ExecutorService mergeExec) {
|
||||||
|
super("Lucene99HnswScalarQuantizedVectorsFormat");
|
||||||
|
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"maxConn must be positive and less than or equal to"
|
||||||
|
+ MAXIMUM_MAX_CONN
|
||||||
|
+ "; maxConn="
|
||||||
|
+ maxConn);
|
||||||
|
}
|
||||||
|
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"beamWidth must be positive and less than or equal to"
|
||||||
|
+ MAXIMUM_BEAM_WIDTH
|
||||||
|
+ "; beamWidth="
|
||||||
|
+ beamWidth);
|
||||||
|
}
|
||||||
|
this.maxConn = maxConn;
|
||||||
|
this.beamWidth = beamWidth;
|
||||||
|
if (numMergeWorkers > 1 && mergeExec == null) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
|
||||||
|
}
|
||||||
|
if (numMergeWorkers == 1 && mergeExec != null) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"No executor service is needed as we'll use single thread to merge");
|
||||||
|
}
|
||||||
|
this.numMergeWorkers = numMergeWorkers;
|
||||||
|
this.mergeExec = mergeExec;
|
||||||
|
this.flatVectorsFormat = new Lucene99ScalarQuantizedVectorsFormat(configuredQuantile);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
return new Lucene99HnswVectorsWriter(
|
||||||
|
state,
|
||||||
|
maxConn,
|
||||||
|
beamWidth,
|
||||||
|
flatVectorsFormat.fieldsWriter(state),
|
||||||
|
numMergeWorkers,
|
||||||
|
mergeExec);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
|
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int getMaxDimensions(String fieldName) {
|
||||||
|
return 1024;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return "Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn="
|
||||||
|
+ maxConn
|
||||||
|
+ ", beamWidth="
|
||||||
|
+ beamWidth
|
||||||
|
+ ", flatVectorFormat="
|
||||||
|
+ flatVectorsFormat
|
||||||
|
+ ")";
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsFormat;
|
||||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
@ -30,23 +31,9 @@ import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Lucene 9.9 vector format, which encodes numeric vector values and an optional associated graph
|
* Lucene 9.9 vector format, which encodes numeric vector values into an associated graph connecting
|
||||||
* connecting the documents having values. The graph is used to power HNSW search. The format
|
* the documents having values. The graph is used to power HNSW search. The format consists of two
|
||||||
* consists of three files, with an optional fourth file:
|
* files, and requires a {@link FlatVectorsFormat} to store the actual vectors:
|
||||||
*
|
|
||||||
* <h2>.vec (vector data) file</h2>
|
|
||||||
*
|
|
||||||
* <p>For each field:
|
|
||||||
*
|
|
||||||
* <ul>
|
|
||||||
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
|
|
||||||
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
|
|
||||||
* sample is stored as an IEEE float in little-endian byte order.
|
|
||||||
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
|
||||||
* note that only in sparse case
|
|
||||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
|
||||||
* that only in sparse case
|
|
||||||
* </ul>
|
|
||||||
*
|
*
|
||||||
* <h2>.vex (vector index)</h2>
|
* <h2>.vex (vector index)</h2>
|
||||||
*
|
*
|
||||||
|
@ -74,14 +61,6 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
* <ul>
|
* <ul>
|
||||||
* <li><b>[int32]</b> field number
|
* <li><b>[int32]</b> field number
|
||||||
* <li><b>[int32]</b> vector similarity function ordinal
|
* <li><b>[int32]</b> vector similarity function ordinal
|
||||||
* <li><b>[byte]</b> if equals to 1 indicates if the field is for quantized vectors
|
|
||||||
* <li><b>[int32]</b> if quantized: the configured quantile float int bits.
|
|
||||||
* <li><b>[int32]</b> if quantized: the calculated lower quantile float int32 bits.
|
|
||||||
* <li><b>[int32]</b> if quantized: the calculated upper quantile float int32 bits.
|
|
||||||
* <li><b>[vlong]</b> if quantized: offset to this field's vectors in the .veq file
|
|
||||||
* <li><b>[vlong]</b> if quantized: length of this field's vectors, in bytes in the .veq file
|
|
||||||
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
|
|
||||||
* <li><b>[vlong]</b> length of this field's vectors, in bytes
|
|
||||||
* <li><b>[vlong]</b> offset to this field's index in the .vex file
|
* <li><b>[vlong]</b> offset to this field's index in the .vex file
|
||||||
* <li><b>[vlong]</b> length of this field's index data, in bytes
|
* <li><b>[vlong]</b> length of this field's index data, in bytes
|
||||||
* <li><b>[vint]</b> dimension of this field's vectors
|
* <li><b>[vint]</b> dimension of this field's vectors
|
||||||
|
@ -101,29 +80,13 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
* </ul>
|
* </ul>
|
||||||
* </ul>
|
* </ul>
|
||||||
*
|
*
|
||||||
* <h2>.veq (quantized vector data) file</h2>
|
|
||||||
*
|
|
||||||
* <p>For each field:
|
|
||||||
*
|
|
||||||
* <ul>
|
|
||||||
* <li>Vector data ordered by field, document ordinal, and vector dimension. Each vector dimension
|
|
||||||
* is stored as a single byte and every vector has a single float32 value for scoring
|
|
||||||
* corrections.
|
|
||||||
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
|
||||||
* note that only in sparse case
|
|
||||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
|
||||||
* that only in sparse case
|
|
||||||
* </ul>
|
|
||||||
*
|
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
|
|
||||||
static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
|
static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
|
||||||
static final String VECTOR_DATA_CODEC_NAME = "Lucene99HnswVectorsFormatData";
|
|
||||||
static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
|
static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
|
||||||
static final String META_EXTENSION = "vem";
|
static final String META_EXTENSION = "vem";
|
||||||
static final String VECTOR_DATA_EXTENSION = "vec";
|
|
||||||
static final String VECTOR_INDEX_EXTENSION = "vex";
|
static final String VECTOR_INDEX_EXTENSION = "vex";
|
||||||
|
|
||||||
public static final int VERSION_START = 0;
|
public static final int VERSION_START = 0;
|
||||||
|
@ -135,7 +98,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
|
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
|
||||||
* numbers here will use an inordinate amount of heap
|
* numbers here will use an inordinate amount of heap
|
||||||
*/
|
*/
|
||||||
private static final int MAXIMUM_MAX_CONN = 512;
|
static final int MAXIMUM_MAX_CONN = 512;
|
||||||
|
|
||||||
/** Default number of maximum connections per node */
|
/** Default number of maximum connections per node */
|
||||||
public static final int DEFAULT_MAX_CONN = 16;
|
public static final int DEFAULT_MAX_CONN = 16;
|
||||||
|
@ -145,7 +108,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
* maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 =
|
* maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 =
|
||||||
* 3200`
|
* 3200`
|
||||||
*/
|
*/
|
||||||
private static final int MAXIMUM_BEAM_WIDTH = 3200;
|
static final int MAXIMUM_BEAM_WIDTH = 3200;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Default number of the size of the queue maintained while searching during a graph construction.
|
* Default number of the size of the queue maintained while searching during a graph construction.
|
||||||
|
@ -170,20 +133,15 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
*/
|
*/
|
||||||
private final int beamWidth;
|
private final int beamWidth;
|
||||||
|
|
||||||
/** Should this codec scalar quantize float32 vectors and use this format */
|
/** The format for storing, reading, merging vectors on disk */
|
||||||
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;
|
private static final FlatVectorsFormat flatVectorsFormat = new Lucene99FlatVectorsFormat();
|
||||||
|
|
||||||
private final int numMergeWorkers;
|
private final int numMergeWorkers;
|
||||||
private final ExecutorService mergeExec;
|
private final ExecutorService mergeExec;
|
||||||
|
|
||||||
/** Constructs a format using default graph construction parameters */
|
/** Constructs a format using default graph construction parameters */
|
||||||
public Lucene99HnswVectorsFormat() {
|
public Lucene99HnswVectorsFormat() {
|
||||||
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
|
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, DEFAULT_NUM_MERGE_WORKER, null);
|
||||||
}
|
|
||||||
|
|
||||||
public Lucene99HnswVectorsFormat(
|
|
||||||
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
|
|
||||||
this(maxConn, beamWidth, scalarQuantize, DEFAULT_NUM_MERGE_WORKER, null);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -193,7 +151,7 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
* @param beamWidth the size of the queue maintained during graph construction.
|
* @param beamWidth the size of the queue maintained during graph construction.
|
||||||
*/
|
*/
|
||||||
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
|
public Lucene99HnswVectorsFormat(int maxConn, int beamWidth) {
|
||||||
this(maxConn, beamWidth, null);
|
this(maxConn, beamWidth, DEFAULT_NUM_MERGE_WORKER, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -201,18 +159,13 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
*
|
*
|
||||||
* @param maxConn the maximum number of connections to a node in the HNSW graph
|
* @param maxConn the maximum number of connections to a node in the HNSW graph
|
||||||
* @param beamWidth the size of the queue maintained during graph construction.
|
* @param beamWidth the size of the queue maintained during graph construction.
|
||||||
* @param scalarQuantize the scalar quantization format
|
|
||||||
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
|
* @param numMergeWorkers number of workers (threads) that will be used when doing merge. If
|
||||||
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
|
* larger than 1, a non-null {@link ExecutorService} must be passed as mergeExec
|
||||||
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
|
* @param mergeExec the {@link ExecutorService} that will be used by ALL vector writers that are
|
||||||
* generated by this format to do the merge
|
* generated by this format to do the merge
|
||||||
*/
|
*/
|
||||||
public Lucene99HnswVectorsFormat(
|
public Lucene99HnswVectorsFormat(
|
||||||
int maxConn,
|
int maxConn, int beamWidth, int numMergeWorkers, ExecutorService mergeExec) {
|
||||||
int beamWidth,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat scalarQuantize,
|
|
||||||
int numMergeWorkers,
|
|
||||||
ExecutorService mergeExec) {
|
|
||||||
super("Lucene99HnswVectorsFormat");
|
super("Lucene99HnswVectorsFormat");
|
||||||
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
|
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
|
@ -228,6 +181,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
+ "; beamWidth="
|
+ "; beamWidth="
|
||||||
+ beamWidth);
|
+ beamWidth);
|
||||||
}
|
}
|
||||||
|
this.maxConn = maxConn;
|
||||||
|
this.beamWidth = beamWidth;
|
||||||
if (numMergeWorkers > 1 && mergeExec == null) {
|
if (numMergeWorkers > 1 && mergeExec == null) {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
|
"No executor service passed in when " + numMergeWorkers + " merge workers are requested");
|
||||||
|
@ -236,9 +191,6 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
throw new IllegalArgumentException(
|
throw new IllegalArgumentException(
|
||||||
"No executor service is needed as we'll use single thread to merge");
|
"No executor service is needed as we'll use single thread to merge");
|
||||||
}
|
}
|
||||||
this.maxConn = maxConn;
|
|
||||||
this.beamWidth = beamWidth;
|
|
||||||
this.scalarQuantizedVectorsFormat = scalarQuantize;
|
|
||||||
this.numMergeWorkers = numMergeWorkers;
|
this.numMergeWorkers = numMergeWorkers;
|
||||||
this.mergeExec = mergeExec;
|
this.mergeExec = mergeExec;
|
||||||
}
|
}
|
||||||
|
@ -246,12 +198,17 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
return new Lucene99HnswVectorsWriter(
|
return new Lucene99HnswVectorsWriter(
|
||||||
state, maxConn, beamWidth, scalarQuantizedVectorsFormat, numMergeWorkers, mergeExec);
|
state,
|
||||||
|
maxConn,
|
||||||
|
beamWidth,
|
||||||
|
flatVectorsFormat.fieldsWriter(state),
|
||||||
|
numMergeWorkers,
|
||||||
|
mergeExec);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
return new Lucene99HnswVectorsReader(state);
|
return new Lucene99HnswVectorsReader(state, flatVectorsFormat.fieldsReader(state));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -265,8 +222,8 @@ public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||||
+ maxConn
|
+ maxConn
|
||||||
+ ", beamWidth="
|
+ ", beamWidth="
|
||||||
+ beamWidth
|
+ beamWidth
|
||||||
+ ", quantizer="
|
+ ", flatVectorFormat="
|
||||||
+ (scalarQuantizedVectorsFormat == null ? "none" : scalarQuantizedVectorsFormat.toString())
|
+ flatVectorsFormat
|
||||||
+ ")";
|
+ ")";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,11 +24,9 @@ import java.util.Arrays;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsReader;
|
||||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
|
||||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
|
||||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
|
||||||
import org.apache.lucene.index.ByteVectorValues;
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
import org.apache.lucene.index.CorruptIndexException;
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
@ -67,49 +65,14 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|
|
||||||
private final FieldInfos fieldInfos;
|
private final FieldInfos fieldInfos;
|
||||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||||
private final IndexInput vectorData;
|
|
||||||
private final IndexInput vectorIndex;
|
private final IndexInput vectorIndex;
|
||||||
private final IndexInput quantizedVectorData;
|
private final FlatVectorsReader flatVectorsReader;
|
||||||
private final Lucene99ScalarQuantizedVectorsReader quantizedVectorsReader;
|
|
||||||
|
|
||||||
Lucene99HnswVectorsReader(SegmentReadState state) throws IOException {
|
Lucene99HnswVectorsReader(SegmentReadState state, FlatVectorsReader flatVectorsReader)
|
||||||
this.fieldInfos = state.fieldInfos;
|
throws IOException {
|
||||||
int versionMeta = readMetadata(state);
|
this.flatVectorsReader = flatVectorsReader;
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
this.fieldInfos = state.fieldInfos;
|
||||||
vectorData =
|
|
||||||
openDataInput(
|
|
||||||
state,
|
|
||||||
versionMeta,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME);
|
|
||||||
vectorIndex =
|
|
||||||
openDataInput(
|
|
||||||
state,
|
|
||||||
versionMeta,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
|
|
||||||
if (fields.values().stream().anyMatch(FieldEntry::hasQuantizedVectors)) {
|
|
||||||
quantizedVectorData =
|
|
||||||
openDataInput(
|
|
||||||
state,
|
|
||||||
versionMeta,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME);
|
|
||||||
quantizedVectorsReader = new Lucene99ScalarQuantizedVectorsReader(quantizedVectorData);
|
|
||||||
} else {
|
|
||||||
quantizedVectorData = null;
|
|
||||||
quantizedVectorsReader = null;
|
|
||||||
}
|
|
||||||
success = true;
|
|
||||||
} finally {
|
|
||||||
if (success == false) {
|
|
||||||
IOUtils.closeWhileHandlingException(this);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private int readMetadata(SegmentReadState state) throws IOException {
|
|
||||||
String metaFileName =
|
String metaFileName =
|
||||||
IndexFileNames.segmentFileName(
|
IndexFileNames.segmentFileName(
|
||||||
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
||||||
|
@ -129,10 +92,30 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
} catch (Throwable exception) {
|
} catch (Throwable exception) {
|
||||||
priorE = exception;
|
priorE = exception;
|
||||||
} finally {
|
} finally {
|
||||||
CodecUtil.checkFooter(meta, priorE);
|
try {
|
||||||
|
CodecUtil.checkFooter(meta, priorE);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.close(flatVectorsReader);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
success = false;
|
||||||
|
try {
|
||||||
|
vectorIndex =
|
||||||
|
openDataInput(
|
||||||
|
state,
|
||||||
|
versionMeta,
|
||||||
|
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
|
||||||
|
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return versionMeta;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private static IndexInput openDataInput(
|
private static IndexInput openDataInput(
|
||||||
|
@ -194,31 +177,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
+ " != "
|
+ " != "
|
||||||
+ fieldEntry.dimension);
|
+ fieldEntry.dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
int byteSize =
|
|
||||||
switch (info.getVectorEncoding()) {
|
|
||||||
case BYTE -> Byte.BYTES;
|
|
||||||
case FLOAT32 -> Float.BYTES;
|
|
||||||
};
|
|
||||||
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
|
|
||||||
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
|
|
||||||
if (numBytes != fieldEntry.vectorDataLength) {
|
|
||||||
throw new IllegalStateException(
|
|
||||||
"Vector data length "
|
|
||||||
+ fieldEntry.vectorDataLength
|
|
||||||
+ " not matching size="
|
|
||||||
+ fieldEntry.size
|
|
||||||
+ " * dim="
|
|
||||||
+ dimension
|
|
||||||
+ " * byteSize="
|
|
||||||
+ byteSize
|
|
||||||
+ " = "
|
|
||||||
+ numBytes);
|
|
||||||
}
|
|
||||||
if (fieldEntry.hasQuantizedVectors()) {
|
|
||||||
Lucene99ScalarQuantizedVectorsReader.validateFieldEntry(
|
|
||||||
info, fieldEntry.dimension, fieldEntry.size, fieldEntry.quantizedVectorDataLength);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||||
|
@ -249,58 +207,24 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
||||||
+ RamUsageEstimator.sizeOfMap(
|
+ RamUsageEstimator.sizeOfMap(
|
||||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class))
|
||||||
|
+ flatVectorsReader.ramBytesUsed();
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
CodecUtil.checksumEntireFile(vectorData);
|
flatVectorsReader.checkIntegrity();
|
||||||
CodecUtil.checksumEntireFile(vectorIndex);
|
CodecUtil.checksumEntireFile(vectorIndex);
|
||||||
if (quantizedVectorsReader != null) {
|
|
||||||
quantizedVectorsReader.checkIntegrity();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
return flatVectorsReader.getFloatVectorValues(field);
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"field=\""
|
|
||||||
+ field
|
|
||||||
+ "\" is encoded as: "
|
|
||||||
+ fieldEntry.vectorEncoding
|
|
||||||
+ " expected: "
|
|
||||||
+ VectorEncoding.FLOAT32);
|
|
||||||
}
|
|
||||||
return OffHeapFloatVectorValues.load(
|
|
||||||
fieldEntry.ordToDoc,
|
|
||||||
fieldEntry.vectorEncoding,
|
|
||||||
fieldEntry.dimension,
|
|
||||||
fieldEntry.vectorDataOffset,
|
|
||||||
fieldEntry.vectorDataLength,
|
|
||||||
vectorData);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
return flatVectorsReader.getByteVectorValues(field);
|
||||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"field=\""
|
|
||||||
+ field
|
|
||||||
+ "\" is encoded as: "
|
|
||||||
+ fieldEntry.vectorEncoding
|
|
||||||
+ " expected: "
|
|
||||||
+ VectorEncoding.FLOAT32);
|
|
||||||
}
|
|
||||||
return OffHeapByteVectorValues.load(
|
|
||||||
fieldEntry.ordToDoc,
|
|
||||||
fieldEntry.vectorEncoding,
|
|
||||||
fieldEntry.dimension,
|
|
||||||
fieldEntry.vectorDataOffset,
|
|
||||||
fieldEntry.vectorDataLength,
|
|
||||||
vectorData);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -313,42 +237,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (fieldEntry.hasQuantizedVectors()) {
|
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
|
||||||
OffHeapQuantizedByteVectorValues vectorValues =
|
HnswGraphSearcher.search(
|
||||||
quantizedVectorsReader.getQuantizedVectorValues(
|
scorer,
|
||||||
fieldEntry.quantizedOrdToDoc,
|
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
|
||||||
fieldEntry.dimension,
|
getGraph(fieldEntry),
|
||||||
fieldEntry.size,
|
scorer.getAcceptOrds(acceptDocs));
|
||||||
fieldEntry.quantizedVectorDataOffset,
|
|
||||||
fieldEntry.quantizedVectorDataLength);
|
|
||||||
if (vectorValues == null) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
RandomVectorScorer scorer =
|
|
||||||
new ScalarQuantizedRandomVectorScorer(
|
|
||||||
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
|
|
||||||
HnswGraphSearcher.search(
|
|
||||||
scorer,
|
|
||||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
|
||||||
getGraph(fieldEntry),
|
|
||||||
vectorValues.getAcceptOrds(acceptDocs));
|
|
||||||
} else {
|
|
||||||
OffHeapFloatVectorValues vectorValues =
|
|
||||||
OffHeapFloatVectorValues.load(
|
|
||||||
fieldEntry.ordToDoc,
|
|
||||||
fieldEntry.vectorEncoding,
|
|
||||||
fieldEntry.dimension,
|
|
||||||
fieldEntry.vectorDataOffset,
|
|
||||||
fieldEntry.vectorDataLength,
|
|
||||||
vectorData);
|
|
||||||
RandomVectorScorer scorer =
|
|
||||||
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
|
|
||||||
HnswGraphSearcher.search(
|
|
||||||
scorer,
|
|
||||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
|
||||||
getGraph(fieldEntry),
|
|
||||||
vectorValues.getAcceptOrds(acceptDocs));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -361,22 +255,12 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
RandomVectorScorer scorer = flatVectorsReader.getRandomVectorScorer(field, target);
|
||||||
OffHeapByteVectorValues vectorValues =
|
|
||||||
OffHeapByteVectorValues.load(
|
|
||||||
fieldEntry.ordToDoc,
|
|
||||||
fieldEntry.vectorEncoding,
|
|
||||||
fieldEntry.dimension,
|
|
||||||
fieldEntry.vectorDataOffset,
|
|
||||||
fieldEntry.vectorDataLength,
|
|
||||||
vectorData);
|
|
||||||
RandomVectorScorer scorer =
|
|
||||||
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
|
|
||||||
HnswGraphSearcher.search(
|
HnswGraphSearcher.search(
|
||||||
scorer,
|
scorer,
|
||||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
new OrdinalTranslatedKnnCollector(knnCollector, scorer::ordToDoc),
|
||||||
getGraph(fieldEntry),
|
getGraph(fieldEntry),
|
||||||
vectorValues.getAcceptOrds(acceptDocs));
|
scorer.getAcceptOrds(acceptDocs));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -399,32 +283,23 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
IOUtils.close(vectorData, vectorIndex, quantizedVectorData);
|
IOUtils.close(flatVectorsReader, vectorIndex);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public OffHeapQuantizedByteVectorValues getQuantizedVectorValues(String field)
|
public QuantizedByteVectorValues getQuantizedVectorValues(String field) throws IOException {
|
||||||
throws IOException {
|
if (flatVectorsReader instanceof QuantizedVectorsReader) {
|
||||||
FieldEntry fieldEntry = fields.get(field);
|
return ((QuantizedVectorsReader) flatVectorsReader).getQuantizedVectorValues(field);
|
||||||
if (fieldEntry == null || fieldEntry.hasQuantizedVectors() == false) {
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
assert quantizedVectorsReader != null && fieldEntry.quantizedOrdToDoc != null;
|
return null;
|
||||||
return quantizedVectorsReader.getQuantizedVectorValues(
|
|
||||||
fieldEntry.quantizedOrdToDoc,
|
|
||||||
fieldEntry.dimension,
|
|
||||||
fieldEntry.size,
|
|
||||||
fieldEntry.quantizedVectorDataOffset,
|
|
||||||
fieldEntry.quantizedVectorDataLength);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public ScalarQuantizer getQuantizationState(String fieldName) {
|
public ScalarQuantizer getQuantizationState(String field) {
|
||||||
FieldEntry field = fields.get(fieldName);
|
if (flatVectorsReader instanceof QuantizedVectorsReader) {
|
||||||
if (field == null || field.hasQuantizedVectors() == false) {
|
return ((QuantizedVectorsReader) flatVectorsReader).getQuantizationState(field);
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
return field.scalarQuantizer;
|
return null;
|
||||||
}
|
}
|
||||||
|
|
||||||
static class FieldEntry implements Accountable {
|
static class FieldEntry implements Accountable {
|
||||||
|
@ -432,8 +307,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
|
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
|
||||||
final VectorSimilarityFunction similarityFunction;
|
final VectorSimilarityFunction similarityFunction;
|
||||||
final VectorEncoding vectorEncoding;
|
final VectorEncoding vectorEncoding;
|
||||||
final long vectorDataOffset;
|
|
||||||
final long vectorDataLength;
|
|
||||||
final long vectorIndexOffset;
|
final long vectorIndexOffset;
|
||||||
final long vectorIndexLength;
|
final long vectorIndexLength;
|
||||||
final int M;
|
final int M;
|
||||||
|
@ -446,13 +319,6 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
final long offsetsOffset;
|
final long offsetsOffset;
|
||||||
final int offsetsBlockShift;
|
final int offsetsBlockShift;
|
||||||
final long offsetsLength;
|
final long offsetsLength;
|
||||||
final OrdToDocDISIReaderConfiguration ordToDoc;
|
|
||||||
|
|
||||||
final float configuredQuantile, lowerQuantile, upperQuantile;
|
|
||||||
final long quantizedVectorDataOffset, quantizedVectorDataLength;
|
|
||||||
final ScalarQuantizer scalarQuantizer;
|
|
||||||
final boolean isQuantized;
|
|
||||||
final OrdToDocDISIReaderConfiguration quantizedOrdToDoc;
|
|
||||||
|
|
||||||
FieldEntry(
|
FieldEntry(
|
||||||
IndexInput input,
|
IndexInput input,
|
||||||
|
@ -461,36 +327,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.similarityFunction = similarityFunction;
|
this.similarityFunction = similarityFunction;
|
||||||
this.vectorEncoding = vectorEncoding;
|
this.vectorEncoding = vectorEncoding;
|
||||||
this.isQuantized = input.readByte() == 1;
|
|
||||||
// Has int8 quantization
|
|
||||||
if (isQuantized) {
|
|
||||||
configuredQuantile = Float.intBitsToFloat(input.readInt());
|
|
||||||
lowerQuantile = Float.intBitsToFloat(input.readInt());
|
|
||||||
upperQuantile = Float.intBitsToFloat(input.readInt());
|
|
||||||
quantizedVectorDataOffset = input.readVLong();
|
|
||||||
quantizedVectorDataLength = input.readVLong();
|
|
||||||
scalarQuantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, configuredQuantile);
|
|
||||||
} else {
|
|
||||||
configuredQuantile = -1;
|
|
||||||
lowerQuantile = -1;
|
|
||||||
upperQuantile = -1;
|
|
||||||
quantizedVectorDataOffset = -1;
|
|
||||||
quantizedVectorDataLength = -1;
|
|
||||||
scalarQuantizer = null;
|
|
||||||
}
|
|
||||||
vectorDataOffset = input.readVLong();
|
|
||||||
vectorDataLength = input.readVLong();
|
|
||||||
vectorIndexOffset = input.readVLong();
|
vectorIndexOffset = input.readVLong();
|
||||||
vectorIndexLength = input.readVLong();
|
vectorIndexLength = input.readVLong();
|
||||||
dimension = input.readVInt();
|
dimension = input.readVInt();
|
||||||
size = input.readInt();
|
size = input.readInt();
|
||||||
if (isQuantized) {
|
|
||||||
quantizedOrdToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
|
||||||
} else {
|
|
||||||
quantizedOrdToDoc = null;
|
|
||||||
}
|
|
||||||
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
|
||||||
|
|
||||||
// read nodes by level
|
// read nodes by level
|
||||||
M = input.readVInt();
|
M = input.readVInt();
|
||||||
numLevels = input.readVInt();
|
numLevels = input.readVInt();
|
||||||
|
@ -526,16 +366,10 @@ public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean hasQuantizedVectors() {
|
|
||||||
return isQuantized;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
return SHALLOW_SIZE
|
return SHALLOW_SIZE
|
||||||
+ Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum()
|
+ Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum()
|
||||||
+ RamUsageEstimator.sizeOf(ordToDoc)
|
|
||||||
+ (quantizedOrdToDoc == null ? 0 : RamUsageEstimator.sizeOf(quantizedOrdToDoc))
|
|
||||||
+ RamUsageEstimator.sizeOf(offsetsMeta);
|
+ RamUsageEstimator.sizeOf(offsetsMeta);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,40 +18,27 @@
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
|
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import java.nio.ByteBuffer;
|
|
||||||
import java.nio.ByteOrder;
|
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.concurrent.ExecutorService;
|
import java.util.concurrent.ExecutorService;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
|
||||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
|
||||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
|
||||||
import org.apache.lucene.index.ByteVectorValues;
|
|
||||||
import org.apache.lucene.index.DocsWithFieldSet;
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
|
||||||
import org.apache.lucene.index.IndexFileNames;
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
import org.apache.lucene.index.VectorEncoding;
|
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.store.IndexInput;
|
|
||||||
import org.apache.lucene.store.IndexOutput;
|
import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.util.ArrayUtil;
|
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.InfoStream;
|
import org.apache.lucene.util.InfoStream;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
import org.apache.lucene.util.ScalarQuantizer;
|
|
||||||
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||||
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
|
import org.apache.lucene.util.hnsw.ConcurrentHnswMerger;
|
||||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||||
|
@ -62,7 +49,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
|
||||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||||
|
|
||||||
|
@ -73,11 +59,13 @@ import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||||
*/
|
*/
|
||||||
public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
|
|
||||||
|
private static final long SHALLOW_RAM_BYTES_USED =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsWriter.class);
|
||||||
private final SegmentWriteState segmentWriteState;
|
private final SegmentWriteState segmentWriteState;
|
||||||
private final IndexOutput meta, vectorData, quantizedVectorData, vectorIndex;
|
private final IndexOutput meta, vectorIndex;
|
||||||
private final int M;
|
private final int M;
|
||||||
private final int beamWidth;
|
private final int beamWidth;
|
||||||
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
|
private final FlatVectorsWriter flatVectorWriter;
|
||||||
private final int numMergeWorkers;
|
private final int numMergeWorkers;
|
||||||
private final ExecutorService mergeExec;
|
private final ExecutorService mergeExec;
|
||||||
|
|
||||||
|
@ -88,42 +76,30 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
SegmentWriteState state,
|
SegmentWriteState state,
|
||||||
int M,
|
int M,
|
||||||
int beamWidth,
|
int beamWidth,
|
||||||
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat,
|
FlatVectorsWriter flatVectorWriter,
|
||||||
int numMergeWorkers,
|
int numMergeWorkers,
|
||||||
ExecutorService mergeExec)
|
ExecutorService mergeExec)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.M = M;
|
this.M = M;
|
||||||
|
this.flatVectorWriter = flatVectorWriter;
|
||||||
this.beamWidth = beamWidth;
|
this.beamWidth = beamWidth;
|
||||||
this.numMergeWorkers = numMergeWorkers;
|
this.numMergeWorkers = numMergeWorkers;
|
||||||
this.mergeExec = mergeExec;
|
this.mergeExec = mergeExec;
|
||||||
segmentWriteState = state;
|
segmentWriteState = state;
|
||||||
|
|
||||||
String metaFileName =
|
String metaFileName =
|
||||||
IndexFileNames.segmentFileName(
|
IndexFileNames.segmentFileName(
|
||||||
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
||||||
|
|
||||||
String vectorDataFileName =
|
|
||||||
IndexFileNames.segmentFileName(
|
|
||||||
state.segmentInfo.name,
|
|
||||||
state.segmentSuffix,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION);
|
|
||||||
|
|
||||||
String indexDataFileName =
|
String indexDataFileName =
|
||||||
IndexFileNames.segmentFileName(
|
IndexFileNames.segmentFileName(
|
||||||
state.segmentInfo.name,
|
state.segmentInfo.name,
|
||||||
state.segmentSuffix,
|
state.segmentSuffix,
|
||||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
|
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
|
||||||
|
|
||||||
final String quantizedVectorDataFileName =
|
|
||||||
quantizedVectorsFormat != null
|
|
||||||
? IndexFileNames.segmentFileName(
|
|
||||||
state.segmentInfo.name,
|
|
||||||
state.segmentSuffix,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION)
|
|
||||||
: null;
|
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
meta = state.directory.createOutput(metaFileName, state.context);
|
meta = state.directory.createOutput(metaFileName, state.context);
|
||||||
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
|
|
||||||
vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
|
vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
|
||||||
|
|
||||||
CodecUtil.writeIndexHeader(
|
CodecUtil.writeIndexHeader(
|
||||||
|
@ -132,34 +108,12 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||||
state.segmentInfo.getId(),
|
state.segmentInfo.getId(),
|
||||||
state.segmentSuffix);
|
state.segmentSuffix);
|
||||||
CodecUtil.writeIndexHeader(
|
|
||||||
vectorData,
|
|
||||||
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME,
|
|
||||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
|
||||||
state.segmentInfo.getId(),
|
|
||||||
state.segmentSuffix);
|
|
||||||
CodecUtil.writeIndexHeader(
|
CodecUtil.writeIndexHeader(
|
||||||
vectorIndex,
|
vectorIndex,
|
||||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
|
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
|
||||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||||
state.segmentInfo.getId(),
|
state.segmentInfo.getId(),
|
||||||
state.segmentSuffix);
|
state.segmentSuffix);
|
||||||
if (quantizedVectorDataFileName != null) {
|
|
||||||
quantizedVectorData =
|
|
||||||
state.directory.createOutput(quantizedVectorDataFileName, state.context);
|
|
||||||
CodecUtil.writeIndexHeader(
|
|
||||||
quantizedVectorData,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME,
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
|
||||||
state.segmentInfo.getId(),
|
|
||||||
state.segmentSuffix);
|
|
||||||
quantizedVectorsWriter =
|
|
||||||
new Lucene99ScalarQuantizedVectorsWriter(
|
|
||||||
quantizedVectorData, quantizedVectorsFormat.quantile);
|
|
||||||
} else {
|
|
||||||
quantizedVectorData = null;
|
|
||||||
quantizedVectorsWriter = null;
|
|
||||||
}
|
|
||||||
success = true;
|
success = true;
|
||||||
} finally {
|
} finally {
|
||||||
if (success == false) {
|
if (success == false) {
|
||||||
|
@ -170,34 +124,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedVectorFieldWriter =
|
|
||||||
null;
|
|
||||||
// Quantization only supports FLOAT32 for now
|
|
||||||
if (quantizedVectorsWriter != null
|
|
||||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
|
||||||
quantizedVectorFieldWriter =
|
|
||||||
quantizedVectorsWriter.addField(fieldInfo, segmentWriteState.infoStream);
|
|
||||||
}
|
|
||||||
FieldWriter<?> newField =
|
FieldWriter<?> newField =
|
||||||
FieldWriter.create(
|
FieldWriter.create(fieldInfo, M, beamWidth, segmentWriteState.infoStream);
|
||||||
fieldInfo, M, beamWidth, segmentWriteState.infoStream, quantizedVectorFieldWriter);
|
|
||||||
fields.add(newField);
|
fields.add(newField);
|
||||||
return newField;
|
return flatVectorWriter.addField(fieldInfo, newField);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||||
|
flatVectorWriter.flush(maxDoc, sortMap);
|
||||||
for (FieldWriter<?> field : fields) {
|
for (FieldWriter<?> field : fields) {
|
||||||
long[] quantizedVectorOffsetAndLen = null;
|
|
||||||
if (field.quantizedWriter != null) {
|
|
||||||
assert quantizedVectorsWriter != null;
|
|
||||||
quantizedVectorOffsetAndLen =
|
|
||||||
quantizedVectorsWriter.flush(sortMap, field.quantizedWriter, field.docsWithField);
|
|
||||||
}
|
|
||||||
if (sortMap == null) {
|
if (sortMap == null) {
|
||||||
writeField(field, maxDoc, quantizedVectorOffsetAndLen);
|
writeField(field);
|
||||||
} else {
|
} else {
|
||||||
writeSortingField(field, maxDoc, sortMap, quantizedVectorOffsetAndLen);
|
writeSortingField(field, sortMap);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -208,40 +148,29 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
throw new IllegalStateException("already finished");
|
throw new IllegalStateException("already finished");
|
||||||
}
|
}
|
||||||
finished = true;
|
finished = true;
|
||||||
if (quantizedVectorsWriter != null) {
|
flatVectorWriter.finish();
|
||||||
quantizedVectorsWriter.finish();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (meta != null) {
|
if (meta != null) {
|
||||||
// write end of fields marker
|
// write end of fields marker
|
||||||
meta.writeInt(-1);
|
meta.writeInt(-1);
|
||||||
CodecUtil.writeFooter(meta);
|
CodecUtil.writeFooter(meta);
|
||||||
}
|
}
|
||||||
if (vectorData != null) {
|
if (vectorIndex != null) {
|
||||||
CodecUtil.writeFooter(vectorData);
|
|
||||||
CodecUtil.writeFooter(vectorIndex);
|
CodecUtil.writeFooter(vectorIndex);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
long total = 0;
|
long total = SHALLOW_RAM_BYTES_USED;
|
||||||
|
total += flatVectorWriter.ramBytesUsed();
|
||||||
for (FieldWriter<?> field : fields) {
|
for (FieldWriter<?> field : fields) {
|
||||||
total += field.ramBytesUsed();
|
total += field.ramBytesUsed();
|
||||||
}
|
}
|
||||||
return total;
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeField(FieldWriter<?> fieldData, int maxDoc, long[] quantizedVecOffsetAndLen)
|
private void writeField(FieldWriter<?> fieldData) throws IOException {
|
||||||
throws IOException {
|
|
||||||
// write vector values
|
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
|
||||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
|
||||||
case BYTE -> writeByteVectors(fieldData);
|
|
||||||
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
|
||||||
}
|
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
|
||||||
|
|
||||||
// write graph
|
// write graph
|
||||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||||
|
@ -249,43 +178,15 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
|
|
||||||
writeMeta(
|
writeMeta(
|
||||||
fieldData.isQuantized(),
|
|
||||||
fieldData.fieldInfo,
|
fieldData.fieldInfo,
|
||||||
maxDoc,
|
|
||||||
fieldData.getConfiguredQuantile(),
|
|
||||||
fieldData.getMinQuantile(),
|
|
||||||
fieldData.getMaxQuantile(),
|
|
||||||
quantizedVecOffsetAndLen,
|
|
||||||
vectorDataOffset,
|
|
||||||
vectorDataLength,
|
|
||||||
vectorIndexOffset,
|
vectorIndexOffset,
|
||||||
vectorIndexLength,
|
vectorIndexLength,
|
||||||
fieldData.docsWithField,
|
fieldData.docsWithField.cardinality(),
|
||||||
graph,
|
graph,
|
||||||
graphLevelNodeOffsets);
|
graphLevelNodeOffsets);
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
|
private void writeSortingField(FieldWriter<?> fieldData, Sorter.DocMap sortMap)
|
||||||
final ByteBuffer buffer =
|
|
||||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
|
||||||
for (Object v : fieldData.vectors) {
|
|
||||||
buffer.asFloatBuffer().put((float[]) v);
|
|
||||||
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
|
|
||||||
for (Object v : fieldData.vectors) {
|
|
||||||
byte[] vector = (byte[]) v;
|
|
||||||
vectorData.writeBytes(vector, vector.length);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void writeSortingField(
|
|
||||||
FieldWriter<?> fieldData,
|
|
||||||
int maxDoc,
|
|
||||||
Sorter.DocMap sortMap,
|
|
||||||
long[] quantizedVectorOffsetAndLen)
|
|
||||||
throws IOException {
|
throws IOException {
|
||||||
final int[] docIdOffsets = new int[sortMap.size()];
|
final int[] docIdOffsets = new int[sortMap.size()];
|
||||||
int offset = 1; // 0 means no vector for this (field, document)
|
int offset = 1; // 0 means no vector for this (field, document)
|
||||||
|
@ -310,15 +211,6 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
doc++;
|
doc++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// write vector values
|
|
||||||
long vectorDataOffset =
|
|
||||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
|
||||||
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
|
|
||||||
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
|
|
||||||
};
|
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
|
||||||
|
|
||||||
// write graph
|
// write graph
|
||||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||||
|
@ -327,44 +219,14 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
|
|
||||||
writeMeta(
|
writeMeta(
|
||||||
fieldData.isQuantized(),
|
|
||||||
fieldData.fieldInfo,
|
fieldData.fieldInfo,
|
||||||
maxDoc,
|
|
||||||
fieldData.getConfiguredQuantile(),
|
|
||||||
fieldData.getMinQuantile(),
|
|
||||||
fieldData.getMaxQuantile(),
|
|
||||||
quantizedVectorOffsetAndLen,
|
|
||||||
vectorDataOffset,
|
|
||||||
vectorDataLength,
|
|
||||||
vectorIndexOffset,
|
vectorIndexOffset,
|
||||||
vectorIndexLength,
|
vectorIndexLength,
|
||||||
newDocsWithField,
|
fieldData.docsWithField.cardinality(),
|
||||||
mockGraph,
|
mockGraph,
|
||||||
graphLevelNodeOffsets);
|
graphLevelNodeOffsets);
|
||||||
}
|
}
|
||||||
|
|
||||||
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
|
|
||||||
throws IOException {
|
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
|
||||||
final ByteBuffer buffer =
|
|
||||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
|
||||||
for (int ordinal : ordMap) {
|
|
||||||
float[] vector = (float[]) fieldData.vectors.get(ordinal);
|
|
||||||
buffer.asFloatBuffer().put(vector);
|
|
||||||
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
|
||||||
}
|
|
||||||
return vectorDataOffset;
|
|
||||||
}
|
|
||||||
|
|
||||||
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
|
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
|
||||||
for (int ordinal : ordMap) {
|
|
||||||
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
|
|
||||||
vectorData.writeBytes(vector, vector.length);
|
|
||||||
}
|
|
||||||
return vectorDataOffset;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reconstructs the graph given the old and new node ids.
|
* Reconstructs the graph given the old and new node ids.
|
||||||
*
|
*
|
||||||
|
@ -475,116 +337,10 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
CloseableRandomVectorScorerSupplier scorerSupplier =
|
||||||
IndexOutput tempVectorData = null;
|
flatVectorWriter.mergeOneFieldToIndex(fieldInfo, mergeState);
|
||||||
IndexInput vectorDataInput = null;
|
|
||||||
CloseableRandomVectorScorerSupplier scorerSupplier = null;
|
|
||||||
boolean success = false;
|
boolean success = false;
|
||||||
try {
|
try {
|
||||||
ScalarQuantizer scalarQuantizer = null;
|
|
||||||
long[] quantizedVectorDataOffsetAndLength = null;
|
|
||||||
// If we have configured quantization and are FLOAT32
|
|
||||||
if (quantizedVectorsWriter != null
|
|
||||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
|
||||||
// We need the quantization parameters to write to the meta file
|
|
||||||
scalarQuantizer = quantizedVectorsWriter.mergeQuantiles(fieldInfo, mergeState);
|
|
||||||
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
|
||||||
segmentWriteState.infoStream.message(
|
|
||||||
QUANTIZED_VECTOR_COMPONENT,
|
|
||||||
"Merged quantiles field: "
|
|
||||||
+ fieldInfo.name
|
|
||||||
+ " newly merged quantile: "
|
|
||||||
+ scalarQuantizer);
|
|
||||||
}
|
|
||||||
assert scalarQuantizer != null;
|
|
||||||
quantizedVectorDataOffsetAndLength = new long[2];
|
|
||||||
quantizedVectorDataOffsetAndLength[0] = quantizedVectorData.alignFilePointer(Float.BYTES);
|
|
||||||
scorerSupplier =
|
|
||||||
quantizedVectorsWriter.mergeOneField(
|
|
||||||
segmentWriteState, fieldInfo, mergeState, scalarQuantizer);
|
|
||||||
quantizedVectorDataOffsetAndLength[1] =
|
|
||||||
quantizedVectorData.getFilePointer() - quantizedVectorDataOffsetAndLength[0];
|
|
||||||
}
|
|
||||||
final DocsWithFieldSet docsWithField;
|
|
||||||
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
|
||||||
|
|
||||||
// If we extract vector storage, this could be cleaner.
|
|
||||||
// But for now, vector storage & index creation/storage live together.
|
|
||||||
if (scorerSupplier == null) {
|
|
||||||
tempVectorData =
|
|
||||||
segmentWriteState.directory.createTempOutput(
|
|
||||||
vectorData.getName(), "temp", segmentWriteState.context);
|
|
||||||
docsWithField =
|
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
|
||||||
case BYTE -> writeByteVectorData(
|
|
||||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
|
||||||
case FLOAT32 -> writeVectorData(
|
|
||||||
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
|
||||||
};
|
|
||||||
CodecUtil.writeFooter(tempVectorData);
|
|
||||||
IOUtils.close(tempVectorData);
|
|
||||||
// copy the temporary file vectors to the actual data file
|
|
||||||
vectorDataInput =
|
|
||||||
segmentWriteState.directory.openInput(
|
|
||||||
tempVectorData.getName(), segmentWriteState.context);
|
|
||||||
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
|
||||||
CodecUtil.retrieveChecksum(vectorDataInput);
|
|
||||||
final RandomVectorScorerSupplier innerScoreSupplier =
|
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
|
||||||
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
|
||||||
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
|
||||||
fieldInfo.getVectorDimension(),
|
|
||||||
docsWithField.cardinality(),
|
|
||||||
vectorDataInput,
|
|
||||||
byteSize),
|
|
||||||
fieldInfo.getVectorSimilarityFunction());
|
|
||||||
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
|
|
||||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
|
||||||
fieldInfo.getVectorDimension(),
|
|
||||||
docsWithField.cardinality(),
|
|
||||||
vectorDataInput,
|
|
||||||
byteSize),
|
|
||||||
fieldInfo.getVectorSimilarityFunction());
|
|
||||||
};
|
|
||||||
final String tempFileName = tempVectorData.getName();
|
|
||||||
final IndexInput finalVectorDataInput = vectorDataInput;
|
|
||||||
scorerSupplier =
|
|
||||||
new CloseableRandomVectorScorerSupplier() {
|
|
||||||
boolean closed = false;
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
|
||||||
return innerScoreSupplier.scorer(ord);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void close() throws IOException {
|
|
||||||
if (closed) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
closed = true;
|
|
||||||
IOUtils.close(finalVectorDataInput);
|
|
||||||
segmentWriteState.directory.deleteFile(tempFileName);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public RandomVectorScorerSupplier copy() throws IOException {
|
|
||||||
// here we just return the inner out since we only need to close this outside copy
|
|
||||||
return innerScoreSupplier.copy();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} else {
|
|
||||||
// No need to use temporary file as we don't have to re-open for reading
|
|
||||||
docsWithField =
|
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
|
||||||
case BYTE -> writeByteVectorData(
|
|
||||||
vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
|
||||||
case FLOAT32 -> writeVectorData(
|
|
||||||
vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
|
||||||
};
|
|
||||||
}
|
|
||||||
|
|
||||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
|
||||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||||
// build the graph using the temporary vector data
|
// build the graph using the temporary vector data
|
||||||
// we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
// we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||||
|
@ -592,7 +348,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
// TODO: separate random access vector values from DocIdSetIterator?
|
// TODO: separate random access vector values from DocIdSetIterator?
|
||||||
OnHeapHnswGraph graph = null;
|
OnHeapHnswGraph graph = null;
|
||||||
int[][] vectorIndexNodeOffsets = null;
|
int[][] vectorIndexNodeOffsets = null;
|
||||||
if (docsWithField.cardinality() != 0) {
|
if (scorerSupplier.totalVectorCount() > 0) {
|
||||||
// build graph
|
// build graph
|
||||||
HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
|
HnswGraphMerger merger = createGraphMerger(fieldInfo, scorerSupplier);
|
||||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||||
|
@ -608,23 +364,17 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
graph =
|
graph =
|
||||||
merger.merge(
|
merger.merge(
|
||||||
mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality());
|
mergedVectorIterator,
|
||||||
|
segmentWriteState.infoStream,
|
||||||
|
scorerSupplier.totalVectorCount());
|
||||||
vectorIndexNodeOffsets = writeGraph(graph);
|
vectorIndexNodeOffsets = writeGraph(graph);
|
||||||
}
|
}
|
||||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||||
writeMeta(
|
writeMeta(
|
||||||
scalarQuantizer != null,
|
|
||||||
fieldInfo,
|
fieldInfo,
|
||||||
segmentWriteState.segmentInfo.maxDoc(),
|
|
||||||
scalarQuantizer == null ? null : scalarQuantizer.getConfiguredQuantile(),
|
|
||||||
scalarQuantizer == null ? null : scalarQuantizer.getLowerQuantile(),
|
|
||||||
scalarQuantizer == null ? null : scalarQuantizer.getUpperQuantile(),
|
|
||||||
quantizedVectorDataOffsetAndLength,
|
|
||||||
vectorDataOffset,
|
|
||||||
vectorDataLength,
|
|
||||||
vectorIndexOffset,
|
vectorIndexOffset,
|
||||||
vectorIndexLength,
|
vectorIndexLength,
|
||||||
docsWithField,
|
scorerSupplier.totalVectorCount(),
|
||||||
graph,
|
graph,
|
||||||
vectorIndexNodeOffsets);
|
vectorIndexNodeOffsets);
|
||||||
success = true;
|
success = true;
|
||||||
|
@ -632,11 +382,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
if (success) {
|
if (success) {
|
||||||
IOUtils.close(scorerSupplier);
|
IOUtils.close(scorerSupplier);
|
||||||
} else {
|
} else {
|
||||||
IOUtils.closeWhileHandlingException(scorerSupplier, vectorDataInput, tempVectorData);
|
IOUtils.closeWhileHandlingException(scorerSupplier);
|
||||||
if (tempVectorData != null) {
|
|
||||||
IOUtils.deleteFilesIgnoringExceptions(
|
|
||||||
segmentWriteState.directory, tempVectorData.getName());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -652,7 +398,7 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
int countOnLevel0 = graph.size();
|
int countOnLevel0 = graph.size();
|
||||||
int[][] offsets = new int[graph.numLevels()][];
|
int[][] offsets = new int[graph.numLevels()][];
|
||||||
for (int level = 0; level < graph.numLevels(); level++) {
|
for (int level = 0; level < graph.numLevels(); level++) {
|
||||||
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
|
int[] sortedNodes = NodesIterator.getSortedNodes(graph.getNodesOnLevel(level));
|
||||||
offsets[level] = new int[sortedNodes.length];
|
offsets[level] = new int[sortedNodes.length];
|
||||||
int nodeOffsetId = 0;
|
int nodeOffsetId = 0;
|
||||||
for (int node : sortedNodes) {
|
for (int node : sortedNodes) {
|
||||||
|
@ -680,80 +426,21 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
return offsets;
|
return offsets;
|
||||||
}
|
}
|
||||||
|
|
||||||
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
|
|
||||||
int[] sortedNodes = new int[nodesOnLevel.size()];
|
|
||||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
|
||||||
sortedNodes[n] = nodesOnLevel.nextInt();
|
|
||||||
}
|
|
||||||
Arrays.sort(sortedNodes);
|
|
||||||
return sortedNodes;
|
|
||||||
}
|
|
||||||
|
|
||||||
private HnswGraphMerger createGraphMerger(
|
|
||||||
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
|
|
||||||
if (mergeExec != null) {
|
|
||||||
return new ConcurrentHnswMerger(
|
|
||||||
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
|
|
||||||
}
|
|
||||||
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void writeMeta(
|
private void writeMeta(
|
||||||
boolean isQuantized,
|
|
||||||
FieldInfo field,
|
FieldInfo field,
|
||||||
int maxDoc,
|
|
||||||
Float configuredQuantizationQuantile,
|
|
||||||
Float lowerQuantile,
|
|
||||||
Float upperQuantile,
|
|
||||||
long[] quantizedVectorDataOffsetAndLen,
|
|
||||||
long vectorDataOffset,
|
|
||||||
long vectorDataLength,
|
|
||||||
long vectorIndexOffset,
|
long vectorIndexOffset,
|
||||||
long vectorIndexLength,
|
long vectorIndexLength,
|
||||||
DocsWithFieldSet docsWithField,
|
int count,
|
||||||
HnswGraph graph,
|
HnswGraph graph,
|
||||||
int[][] graphLevelNodeOffsets)
|
int[][] graphLevelNodeOffsets)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
meta.writeInt(field.number);
|
meta.writeInt(field.number);
|
||||||
meta.writeInt(field.getVectorEncoding().ordinal());
|
meta.writeInt(field.getVectorEncoding().ordinal());
|
||||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||||
meta.writeByte(isQuantized ? (byte) 1 : (byte) 0);
|
|
||||||
if (isQuantized) {
|
|
||||||
assert lowerQuantile != null
|
|
||||||
&& upperQuantile != null
|
|
||||||
&& quantizedVectorDataOffsetAndLen != null;
|
|
||||||
assert quantizedVectorDataOffsetAndLen.length == 2;
|
|
||||||
meta.writeInt(
|
|
||||||
Float.floatToIntBits(
|
|
||||||
configuredQuantizationQuantile != null
|
|
||||||
? configuredQuantizationQuantile
|
|
||||||
: calculateDefaultQuantile(field.getVectorDimension())));
|
|
||||||
meta.writeInt(Float.floatToIntBits(lowerQuantile));
|
|
||||||
meta.writeInt(Float.floatToIntBits(upperQuantile));
|
|
||||||
meta.writeVLong(quantizedVectorDataOffsetAndLen[0]);
|
|
||||||
meta.writeVLong(quantizedVectorDataOffsetAndLen[1]);
|
|
||||||
} else {
|
|
||||||
assert configuredQuantizationQuantile == null
|
|
||||||
&& lowerQuantile == null
|
|
||||||
&& upperQuantile == null
|
|
||||||
&& quantizedVectorDataOffsetAndLen == null;
|
|
||||||
}
|
|
||||||
meta.writeVLong(vectorDataOffset);
|
|
||||||
meta.writeVLong(vectorDataLength);
|
|
||||||
meta.writeVLong(vectorIndexOffset);
|
meta.writeVLong(vectorIndexOffset);
|
||||||
meta.writeVLong(vectorIndexLength);
|
meta.writeVLong(vectorIndexLength);
|
||||||
meta.writeVInt(field.getVectorDimension());
|
meta.writeVInt(field.getVectorDimension());
|
||||||
|
|
||||||
// write docIDs
|
|
||||||
int count = docsWithField.cardinality();
|
|
||||||
meta.writeInt(count);
|
meta.writeInt(count);
|
||||||
if (isQuantized) {
|
|
||||||
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
|
||||||
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
|
|
||||||
}
|
|
||||||
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
|
||||||
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
|
|
||||||
|
|
||||||
meta.writeVInt(M);
|
meta.writeVInt(M);
|
||||||
// write graph nodes on each level
|
// write graph nodes on each level
|
||||||
if (graph == null) {
|
if (graph == null) {
|
||||||
|
@ -799,109 +486,47 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
private HnswGraphMerger createGraphMerger(
|
||||||
* Writes the byte vector values to the output and returns a set of documents that contains
|
FieldInfo fieldInfo, RandomVectorScorerSupplier scorerSupplier) {
|
||||||
* vectors.
|
if (mergeExec != null) {
|
||||||
*/
|
return new ConcurrentHnswMerger(
|
||||||
private static DocsWithFieldSet writeByteVectorData(
|
fieldInfo, scorerSupplier, M, beamWidth, mergeExec, numMergeWorkers);
|
||||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
|
||||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
|
||||||
for (int docV = byteVectorValues.nextDoc();
|
|
||||||
docV != NO_MORE_DOCS;
|
|
||||||
docV = byteVectorValues.nextDoc()) {
|
|
||||||
// write vector
|
|
||||||
byte[] binaryValue = byteVectorValues.vectorValue();
|
|
||||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
|
||||||
output.writeBytes(binaryValue, binaryValue.length);
|
|
||||||
docsWithField.add(docV);
|
|
||||||
}
|
}
|
||||||
return docsWithField;
|
return new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
|
||||||
*/
|
|
||||||
private static DocsWithFieldSet writeVectorData(
|
|
||||||
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
|
|
||||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
|
||||||
ByteBuffer buffer =
|
|
||||||
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
|
||||||
.order(ByteOrder.LITTLE_ENDIAN);
|
|
||||||
for (int docV = floatVectorValues.nextDoc();
|
|
||||||
docV != NO_MORE_DOCS;
|
|
||||||
docV = floatVectorValues.nextDoc()) {
|
|
||||||
// write vector
|
|
||||||
float[] value = floatVectorValues.vectorValue();
|
|
||||||
buffer.asFloatBuffer().put(value);
|
|
||||||
output.writeBytes(buffer.array(), buffer.limit());
|
|
||||||
docsWithField.add(docV);
|
|
||||||
}
|
|
||||||
return docsWithField;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
|
IOUtils.close(meta, vectorIndex, flatVectorWriter);
|
||||||
}
|
}
|
||||||
|
|
||||||
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
private static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||||
|
|
||||||
|
private static final long SHALLOW_SIZE =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(FieldWriter.class);
|
||||||
|
|
||||||
private final FieldInfo fieldInfo;
|
private final FieldInfo fieldInfo;
|
||||||
private final int dim;
|
|
||||||
private final DocsWithFieldSet docsWithField;
|
private final DocsWithFieldSet docsWithField;
|
||||||
private final List<T> vectors;
|
private final List<T> vectors;
|
||||||
private final HnswGraphBuilder hnswGraphBuilder;
|
private final HnswGraphBuilder hnswGraphBuilder;
|
||||||
private final Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter
|
|
||||||
quantizedWriter;
|
|
||||||
|
|
||||||
private int lastDocID = -1;
|
private int lastDocID = -1;
|
||||||
private int node = 0;
|
private int node = 0;
|
||||||
|
|
||||||
static FieldWriter<?> create(
|
static FieldWriter<?> create(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||||
FieldInfo fieldInfo,
|
|
||||||
int M,
|
|
||||||
int beamWidth,
|
|
||||||
InfoStream infoStream,
|
|
||||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter writer)
|
|
||||||
throws IOException {
|
throws IOException {
|
||||||
int dim = fieldInfo.getVectorDimension();
|
|
||||||
return switch (fieldInfo.getVectorEncoding()) {
|
return switch (fieldInfo.getVectorEncoding()) {
|
||||||
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream, writer) {
|
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream);
|
||||||
@Override
|
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream);
|
||||||
public byte[] copyValue(byte[] value) {
|
|
||||||
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream, writer) {
|
|
||||||
@Override
|
|
||||||
public float[] copyValue(float[] value) {
|
|
||||||
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@SuppressWarnings("unchecked")
|
@SuppressWarnings("unchecked")
|
||||||
FieldWriter(
|
FieldWriter(FieldInfo fieldInfo, int M, int beamWidth, InfoStream infoStream)
|
||||||
FieldInfo fieldInfo,
|
|
||||||
int M,
|
|
||||||
int beamWidth,
|
|
||||||
InfoStream infoStream,
|
|
||||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedWriter)
|
|
||||||
throws IOException {
|
throws IOException {
|
||||||
this.fieldInfo = fieldInfo;
|
this.fieldInfo = fieldInfo;
|
||||||
this.dim = fieldInfo.getVectorDimension();
|
|
||||||
this.docsWithField = new DocsWithFieldSet();
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
this.quantizedWriter = quantizedWriter;
|
|
||||||
vectors = new ArrayList<>();
|
vectors = new ArrayList<>();
|
||||||
if (quantizedWriter != null
|
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, fieldInfo.getVectorDimension());
|
||||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Vector encoding ["
|
|
||||||
+ VectorEncoding.FLOAT32
|
|
||||||
+ "] required for quantized vectors; provided="
|
|
||||||
+ fieldInfo.getVectorEncoding());
|
|
||||||
}
|
|
||||||
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
|
|
||||||
RandomVectorScorerSupplier scorerSupplier =
|
RandomVectorScorerSupplier scorerSupplier =
|
||||||
switch (fieldInfo.getVectorEncoding()) {
|
switch (fieldInfo.getVectorEncoding()) {
|
||||||
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
||||||
|
@ -925,20 +550,20 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||||
}
|
}
|
||||||
assert docID > lastDocID;
|
assert docID > lastDocID;
|
||||||
T copy = copyValue(vectorValue);
|
vectors.add(vectorValue);
|
||||||
if (quantizedWriter != null) {
|
|
||||||
assert vectorValue instanceof float[];
|
|
||||||
quantizedWriter.addValue((float[]) copy);
|
|
||||||
}
|
|
||||||
docsWithField.add(docID);
|
docsWithField.add(docID);
|
||||||
vectors.add(copy);
|
|
||||||
hnswGraphBuilder.addGraphNode(node);
|
hnswGraphBuilder.addGraphNode(node);
|
||||||
node++;
|
node++;
|
||||||
lastDocID = docID;
|
lastDocID = docID;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public T copyValue(T vectorValue) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
OnHeapHnswGraph getGraph() {
|
OnHeapHnswGraph getGraph() {
|
||||||
if (vectors.size() > 0) {
|
if (node > 0) {
|
||||||
return hnswGraphBuilder.getGraph();
|
return hnswGraphBuilder.getGraph();
|
||||||
} else {
|
} else {
|
||||||
return null;
|
return null;
|
||||||
|
@ -947,32 +572,11 @@ public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
if (vectors.size() == 0) return 0;
|
return SHALLOW_SIZE
|
||||||
long quantizationSpace = quantizedWriter != null ? quantizedWriter.ramBytesUsed() : 0L;
|
+ docsWithField.ramBytesUsed()
|
||||||
return docsWithField.ramBytesUsed()
|
|
||||||
+ (long) vectors.size()
|
+ (long) vectors.size()
|
||||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||||
+ (long) vectors.size()
|
+ hnswGraphBuilder.getGraph().ramBytesUsed();
|
||||||
* fieldInfo.getVectorDimension()
|
|
||||||
* fieldInfo.getVectorEncoding().byteSize
|
|
||||||
+ hnswGraphBuilder.getGraph().ramBytesUsed()
|
|
||||||
+ quantizationSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
Float getConfiguredQuantile() {
|
|
||||||
return quantizedWriter == null ? null : quantizedWriter.getQuantile();
|
|
||||||
}
|
|
||||||
|
|
||||||
Float getMinQuantile() {
|
|
||||||
return quantizedWriter == null ? null : quantizedWriter.getMinQuantile();
|
|
||||||
}
|
|
||||||
|
|
||||||
Float getMaxQuantile() {
|
|
||||||
return quantizedWriter == null ? null : quantizedWriter.getMaxQuantile();
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean isQuantized() {
|
|
||||||
return quantizedWriter != null;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -17,20 +17,31 @@
|
||||||
|
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsReader;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Format supporting vector quantization, storage, and retrieval
|
* Format supporting vector quantization, storage, and retrieval
|
||||||
*
|
*
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public final class Lucene99ScalarQuantizedVectorsFormat {
|
public final class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {
|
||||||
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
|
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
|
||||||
|
|
||||||
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
|
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
|
||||||
|
|
||||||
static final int VERSION_START = 0;
|
static final int VERSION_START = 0;
|
||||||
static final int VERSION_CURRENT = VERSION_START;
|
static final int VERSION_CURRENT = VERSION_START;
|
||||||
static final String QUANTIZED_VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsData";
|
static final String META_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatMeta";
|
||||||
static final String QUANTIZED_VECTOR_DATA_EXTENSION = "veq";
|
static final String VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsFormatData";
|
||||||
|
static final String META_EXTENSION = "vemq";
|
||||||
|
static final String VECTOR_DATA_EXTENSION = "veq";
|
||||||
|
|
||||||
|
private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat();
|
||||||
|
|
||||||
/** The minimum quantile */
|
/** The minimum quantile */
|
||||||
private static final float MINIMUM_QUANTILE = 0.9f;
|
private static final float MINIMUM_QUANTILE = 0.9f;
|
||||||
|
@ -74,6 +85,24 @@ public final class Lucene99ScalarQuantizedVectorsFormat {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String toString() {
|
public String toString() {
|
||||||
return NAME + "(name=" + NAME + ", quantile=" + quantile + ")";
|
return NAME
|
||||||
|
+ "(name="
|
||||||
|
+ NAME
|
||||||
|
+ ", quantile="
|
||||||
|
+ quantile
|
||||||
|
+ ", rawVectorFormat="
|
||||||
|
+ rawVectorFormat
|
||||||
|
+ ")";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||||
|
return new Lucene99ScalarQuantizedVectorsWriter(
|
||||||
|
state, quantile, rawVectorFormat.fieldsWriter(state));
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public FlatVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||||
|
return new Lucene99ScalarQuantizedVectorsReader(state, rawVectorFormat.fieldsReader(state));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,46 +18,128 @@
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsReader;
|
||||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||||
|
import org.apache.lucene.index.ByteVectorValues;
|
||||||
|
import org.apache.lucene.index.CorruptIndexException;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
|
import org.apache.lucene.index.FieldInfos;
|
||||||
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
|
import org.apache.lucene.index.SegmentReadState;
|
||||||
|
import org.apache.lucene.index.VectorEncoding;
|
||||||
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.store.ChecksumIndexInput;
|
||||||
|
import org.apache.lucene.store.DataInput;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
|
import org.apache.lucene.util.Accountable;
|
||||||
|
import org.apache.lucene.util.IOUtils;
|
||||||
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
import org.apache.lucene.util.ScalarQuantizer;
|
||||||
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Reads Scalar Quantized vectors from the index segments along with index data structures.
|
* Reads Scalar Quantized vectors from the index segments along with index data structures.
|
||||||
*
|
*
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public final class Lucene99ScalarQuantizedVectorsReader {
|
public final class Lucene99ScalarQuantizedVectorsReader extends FlatVectorsReader
|
||||||
|
implements QuantizedVectorsReader {
|
||||||
|
|
||||||
|
private static final long SHALLOW_SIZE =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsReader.class);
|
||||||
|
|
||||||
|
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||||
private final IndexInput quantizedVectorData;
|
private final IndexInput quantizedVectorData;
|
||||||
|
private final FlatVectorsReader rawVectorsReader;
|
||||||
|
|
||||||
Lucene99ScalarQuantizedVectorsReader(IndexInput quantizedVectorData) {
|
Lucene99ScalarQuantizedVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader)
|
||||||
this.quantizedVectorData = quantizedVectorData;
|
throws IOException {
|
||||||
|
int versionMeta = -1;
|
||||||
|
String metaFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
|
||||||
|
boolean success = false;
|
||||||
|
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
|
||||||
|
Throwable priorE = null;
|
||||||
|
try {
|
||||||
|
versionMeta =
|
||||||
|
CodecUtil.checkIndexHeader(
|
||||||
|
meta,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
readFields(meta, state.fieldInfos);
|
||||||
|
} catch (Throwable exception) {
|
||||||
|
priorE = exception;
|
||||||
|
} finally {
|
||||||
|
try {
|
||||||
|
CodecUtil.checkFooter(meta, priorE);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.close(rawVectorsReader);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
success = false;
|
||||||
|
this.rawVectorsReader = rawVectorsReader;
|
||||||
|
try {
|
||||||
|
quantizedVectorData =
|
||||||
|
openDataInput(
|
||||||
|
state,
|
||||||
|
versionMeta,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void validateFieldEntry(
|
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
|
||||||
FieldInfo info, int fieldDimension, int size, long quantizedVectorDataLength) {
|
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
|
||||||
|
FieldInfo info = infos.fieldInfo(fieldNumber);
|
||||||
|
if (info == null) {
|
||||||
|
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||||
|
}
|
||||||
|
FieldEntry fieldEntry = readField(meta);
|
||||||
|
validateFieldEntry(info, fieldEntry);
|
||||||
|
fields.put(info.name, fieldEntry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
|
||||||
int dimension = info.getVectorDimension();
|
int dimension = info.getVectorDimension();
|
||||||
if (dimension != fieldDimension) {
|
if (dimension != fieldEntry.dimension) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Inconsistent vector dimension for field=\""
|
"Inconsistent vector dimension for field=\""
|
||||||
+ info.name
|
+ info.name
|
||||||
+ "\"; "
|
+ "\"; "
|
||||||
+ dimension
|
+ dimension
|
||||||
+ " != "
|
+ " != "
|
||||||
+ fieldDimension);
|
+ fieldEntry.dimension);
|
||||||
}
|
}
|
||||||
|
|
||||||
// int8 quantized and calculated stored offset.
|
// int8 quantized and calculated stored offset.
|
||||||
long quantizedVectorBytes = dimension + Float.BYTES;
|
long quantizedVectorBytes = dimension + Float.BYTES;
|
||||||
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, size);
|
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, fieldEntry.size);
|
||||||
if (numQuantizedVectorBytes != quantizedVectorDataLength) {
|
if (numQuantizedVectorBytes != fieldEntry.vectorDataLength) {
|
||||||
throw new IllegalStateException(
|
throw new IllegalStateException(
|
||||||
"Quantized vector data length "
|
"Quantized vector data length "
|
||||||
+ quantizedVectorDataLength
|
+ fieldEntry.vectorDataLength
|
||||||
+ " not matching size="
|
+ " not matching size="
|
||||||
+ size
|
+ fieldEntry.size
|
||||||
+ " * (dim="
|
+ " * (dim="
|
||||||
+ dimension
|
+ dimension
|
||||||
+ " + 4)"
|
+ " + 4)"
|
||||||
|
@ -66,23 +148,184 @@ public final class Lucene99ScalarQuantizedVectorsReader {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
public void checkIntegrity() throws IOException {
|
public void checkIntegrity() throws IOException {
|
||||||
|
rawVectorsReader.checkIntegrity();
|
||||||
CodecUtil.checksumEntireFile(quantizedVectorData);
|
CodecUtil.checksumEntireFile(quantizedVectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
OffHeapQuantizedByteVectorValues getQuantizedVectorValues(
|
@Override
|
||||||
OrdToDocDISIReaderConfiguration configuration,
|
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||||
int dimension,
|
return rawVectorsReader.getFloatVectorValues(field);
|
||||||
int size,
|
}
|
||||||
long quantizedVectorDataOffset,
|
|
||||||
long quantizedVectorDataLength)
|
@Override
|
||||||
|
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||||
|
return rawVectorsReader.getByteVectorValues(field);
|
||||||
|
}
|
||||||
|
|
||||||
|
private static IndexInput openDataInput(
|
||||||
|
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
|
String fileName =
|
||||||
|
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
|
||||||
|
IndexInput in = state.directory.openInput(fileName, state.context);
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
int versionVectorData =
|
||||||
|
CodecUtil.checkIndexHeader(
|
||||||
|
in,
|
||||||
|
codecName,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_START,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
if (versionMeta != versionVectorData) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"Format versions mismatch: meta="
|
||||||
|
+ versionMeta
|
||||||
|
+ ", "
|
||||||
|
+ codecName
|
||||||
|
+ "="
|
||||||
|
+ versionVectorData,
|
||||||
|
in);
|
||||||
|
}
|
||||||
|
CodecUtil.retrieveChecksum(in);
|
||||||
|
success = true;
|
||||||
|
return in;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(in);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer getRandomVectorScorer(String field, float[] target) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(field);
|
||||||
|
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
if (fieldEntry.scalarQuantizer == null) {
|
||||||
|
return rawVectorsReader.getRandomVectorScorer(field, target);
|
||||||
|
}
|
||||||
|
OffHeapQuantizedByteVectorValues vectorValues =
|
||||||
|
OffHeapQuantizedByteVectorValues.load(
|
||||||
|
fieldEntry.ordToDoc,
|
||||||
|
fieldEntry.dimension,
|
||||||
|
fieldEntry.size,
|
||||||
|
fieldEntry.vectorDataOffset,
|
||||||
|
fieldEntry.vectorDataLength,
|
||||||
|
quantizedVectorData);
|
||||||
|
return new ScalarQuantizedRandomVectorScorer(
|
||||||
|
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public RandomVectorScorer getRandomVectorScorer(String field, byte[] target) throws IOException {
|
||||||
|
return rawVectorsReader.getRandomVectorScorer(field, target);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void close() throws IOException {
|
||||||
|
IOUtils.close(quantizedVectorData, rawVectorsReader);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
long size = SHALLOW_SIZE;
|
||||||
|
size +=
|
||||||
|
RamUsageEstimator.sizeOfMap(
|
||||||
|
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
||||||
|
size += rawVectorsReader.ramBytesUsed();
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
private FieldEntry readField(IndexInput input) throws IOException {
|
||||||
|
VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||||
|
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||||
|
return new FieldEntry(input, vectorEncoding, similarityFunction);
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||||
|
int similarityFunctionId = input.readInt();
|
||||||
|
if (similarityFunctionId < 0
|
||||||
|
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
|
||||||
|
throw new CorruptIndexException(
|
||||||
|
"Invalid similarity function id: " + similarityFunctionId, input);
|
||||||
|
}
|
||||||
|
return VectorSimilarityFunction.values()[similarityFunctionId];
|
||||||
|
}
|
||||||
|
|
||||||
|
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||||
|
int encodingId = input.readInt();
|
||||||
|
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||||
|
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||||
|
}
|
||||||
|
return VectorEncoding.values()[encodingId];
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException {
|
||||||
|
FieldEntry fieldEntry = fields.get(fieldName);
|
||||||
|
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
return OffHeapQuantizedByteVectorValues.load(
|
return OffHeapQuantizedByteVectorValues.load(
|
||||||
configuration,
|
fieldEntry.ordToDoc,
|
||||||
dimension,
|
fieldEntry.dimension,
|
||||||
size,
|
fieldEntry.size,
|
||||||
quantizedVectorDataOffset,
|
fieldEntry.vectorDataOffset,
|
||||||
quantizedVectorDataLength,
|
fieldEntry.vectorDataLength,
|
||||||
quantizedVectorData);
|
quantizedVectorData);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ScalarQuantizer getQuantizationState(String fieldName) {
|
||||||
|
FieldEntry fieldEntry = fields.get(fieldName);
|
||||||
|
if (fieldEntry == null || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return fieldEntry.scalarQuantizer;
|
||||||
|
}
|
||||||
|
|
||||||
|
private static class FieldEntry implements Accountable {
|
||||||
|
private static final long SHALLOW_SIZE =
|
||||||
|
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
|
||||||
|
final VectorSimilarityFunction similarityFunction;
|
||||||
|
final VectorEncoding vectorEncoding;
|
||||||
|
final int dimension;
|
||||||
|
final long vectorDataOffset;
|
||||||
|
final long vectorDataLength;
|
||||||
|
final ScalarQuantizer scalarQuantizer;
|
||||||
|
final int size;
|
||||||
|
final OrdToDocDISIReaderConfiguration ordToDoc;
|
||||||
|
|
||||||
|
FieldEntry(
|
||||||
|
IndexInput input,
|
||||||
|
VectorEncoding vectorEncoding,
|
||||||
|
VectorSimilarityFunction similarityFunction)
|
||||||
|
throws IOException {
|
||||||
|
this.similarityFunction = similarityFunction;
|
||||||
|
this.vectorEncoding = vectorEncoding;
|
||||||
|
vectorDataOffset = input.readVLong();
|
||||||
|
vectorDataLength = input.readVLong();
|
||||||
|
dimension = input.readVInt();
|
||||||
|
size = input.readInt();
|
||||||
|
if (size > 0) {
|
||||||
|
float configuredQuantile = Float.intBitsToFloat(input.readInt());
|
||||||
|
float minQuantile = Float.intBitsToFloat(input.readInt());
|
||||||
|
float maxQuantile = Float.intBitsToFloat(input.readInt());
|
||||||
|
scalarQuantizer = new ScalarQuantizer(minQuantile, maxQuantile, configuredQuantile);
|
||||||
|
} else {
|
||||||
|
scalarQuantizer = null;
|
||||||
|
}
|
||||||
|
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public long ramBytesUsed() {
|
||||||
|
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(ordToDoc);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
|
|
||||||
package org.apache.lucene.codecs.lucene99;
|
package org.apache.lucene.codecs.lucene99;
|
||||||
|
|
||||||
|
import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
||||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
|
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
|
||||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||||
|
@ -29,13 +30,18 @@ import java.nio.ByteOrder;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.apache.lucene.codecs.CodecUtil;
|
import org.apache.lucene.codecs.CodecUtil;
|
||||||
|
import org.apache.lucene.codecs.FlatFieldVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.FlatVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||||
|
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
import org.apache.lucene.index.DocIDMerger;
|
import org.apache.lucene.index.DocIDMerger;
|
||||||
import org.apache.lucene.index.DocsWithFieldSet;
|
import org.apache.lucene.index.DocsWithFieldSet;
|
||||||
import org.apache.lucene.index.FieldInfo;
|
import org.apache.lucene.index.FieldInfo;
|
||||||
import org.apache.lucene.index.FloatVectorValues;
|
import org.apache.lucene.index.FloatVectorValues;
|
||||||
|
import org.apache.lucene.index.IndexFileNames;
|
||||||
import org.apache.lucene.index.MergeState;
|
import org.apache.lucene.index.MergeState;
|
||||||
import org.apache.lucene.index.SegmentWriteState;
|
import org.apache.lucene.index.SegmentWriteState;
|
||||||
import org.apache.lucene.index.Sorter;
|
import org.apache.lucene.index.Sorter;
|
||||||
|
@ -44,7 +50,6 @@ import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.search.DocIdSetIterator;
|
import org.apache.lucene.search.DocIdSetIterator;
|
||||||
import org.apache.lucene.store.IndexInput;
|
import org.apache.lucene.store.IndexInput;
|
||||||
import org.apache.lucene.store.IndexOutput;
|
import org.apache.lucene.store.IndexOutput;
|
||||||
import org.apache.lucene.util.Accountable;
|
|
||||||
import org.apache.lucene.util.IOUtils;
|
import org.apache.lucene.util.IOUtils;
|
||||||
import org.apache.lucene.util.InfoStream;
|
import org.apache.lucene.util.InfoStream;
|
||||||
import org.apache.lucene.util.RamUsageEstimator;
|
import org.apache.lucene.util.RamUsageEstimator;
|
||||||
|
@ -59,9 +64,9 @@ import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||||
*
|
*
|
||||||
* @lucene.experimental
|
* @lucene.experimental
|
||||||
*/
|
*/
|
||||||
public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWriter {
|
||||||
|
|
||||||
private static final long BASE_RAM_BYTES_USED =
|
private static final long SHALLOW_RAM_BYTES_USED =
|
||||||
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
|
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
|
||||||
|
|
||||||
// Used for determining when merged quantiles shifted too far from individual segment quantiles.
|
// Used for determining when merged quantiles shifted too far from individual segment quantiles.
|
||||||
|
@ -82,67 +87,214 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
// the quantization error) and the condition is sensitive enough to detect all adversarial cases,
|
// the quantization error) and the condition is sensitive enough to detect all adversarial cases,
|
||||||
// such as merging clustered data.
|
// such as merging clustered data.
|
||||||
private static final float REQUANTIZATION_LIMIT = 0.2f;
|
private static final float REQUANTIZATION_LIMIT = 0.2f;
|
||||||
private final IndexOutput quantizedVectorData;
|
private final SegmentWriteState segmentWriteState;
|
||||||
|
|
||||||
|
private final List<FieldWriter> fields = new ArrayList<>();
|
||||||
|
private final IndexOutput meta, quantizedVectorData;
|
||||||
private final Float quantile;
|
private final Float quantile;
|
||||||
|
private final FlatVectorsWriter rawVectorDelegate;
|
||||||
private boolean finished;
|
private boolean finished;
|
||||||
|
|
||||||
Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) {
|
Lucene99ScalarQuantizedVectorsWriter(
|
||||||
this.quantile = quantile;
|
SegmentWriteState state, Float quantile, FlatVectorsWriter rawVectorDelegate)
|
||||||
this.quantizedVectorData = quantizedVectorData;
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantizationFieldVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) {
|
|
||||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
|
||||||
throw new IllegalArgumentException(
|
|
||||||
"Only float32 vector fields are supported for quantization");
|
|
||||||
}
|
|
||||||
float quantile =
|
|
||||||
this.quantile == null
|
|
||||||
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
|
||||||
: this.quantile;
|
|
||||||
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
|
||||||
infoStream.message(
|
|
||||||
QUANTIZED_VECTOR_COMPONENT,
|
|
||||||
"quantizing field="
|
|
||||||
+ fieldInfo.name
|
|
||||||
+ " dimension="
|
|
||||||
+ fieldInfo.getVectorDimension()
|
|
||||||
+ " quantile="
|
|
||||||
+ quantile);
|
|
||||||
}
|
|
||||||
return QuantizationFieldVectorWriter.create(fieldInfo, quantile, infoStream);
|
|
||||||
}
|
|
||||||
|
|
||||||
long[] flush(
|
|
||||||
Sorter.DocMap sortMap, QuantizationFieldVectorWriter field, DocsWithFieldSet docsWithField)
|
|
||||||
throws IOException {
|
throws IOException {
|
||||||
field.finish();
|
this.quantile = quantile;
|
||||||
return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField);
|
segmentWriteState = state;
|
||||||
|
String metaFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.META_EXTENSION);
|
||||||
|
|
||||||
|
String quantizedVectorDataFileName =
|
||||||
|
IndexFileNames.segmentFileName(
|
||||||
|
state.segmentInfo.name,
|
||||||
|
state.segmentSuffix,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_EXTENSION);
|
||||||
|
this.rawVectorDelegate = rawVectorDelegate;
|
||||||
|
boolean success = false;
|
||||||
|
try {
|
||||||
|
meta = state.directory.createOutput(metaFileName, state.context);
|
||||||
|
quantizedVectorData =
|
||||||
|
state.directory.createOutput(quantizedVectorDataFileName, state.context);
|
||||||
|
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
meta,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.META_CODEC_NAME,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
CodecUtil.writeIndexHeader(
|
||||||
|
quantizedVectorData,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VECTOR_DATA_CODEC_NAME,
|
||||||
|
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
||||||
|
state.segmentInfo.getId(),
|
||||||
|
state.segmentSuffix);
|
||||||
|
success = true;
|
||||||
|
} finally {
|
||||||
|
if (success == false) {
|
||||||
|
IOUtils.closeWhileHandlingException(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void finish() throws IOException {
|
@Override
|
||||||
|
public FlatFieldVectorsWriter<?> addField(
|
||||||
|
FieldInfo fieldInfo, KnnFieldVectorsWriter<?> indexWriter) throws IOException {
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
|
float quantile =
|
||||||
|
this.quantile == null
|
||||||
|
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||||
|
: this.quantile;
|
||||||
|
FieldWriter quantizedWriter =
|
||||||
|
new FieldWriter(quantile, fieldInfo, segmentWriteState.infoStream, indexWriter);
|
||||||
|
fields.add(quantizedWriter);
|
||||||
|
indexWriter = quantizedWriter;
|
||||||
|
}
|
||||||
|
return rawVectorDelegate.addField(fieldInfo, indexWriter);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||||
|
// Since we know we will not be searching for additional indexing, we can just write the
|
||||||
|
// the vectors directly to the new segment.
|
||||||
|
// No need to use temporary file as we don't have to re-open for reading
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
|
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
|
||||||
|
MergedQuantizedVectorValues byteVectorValues =
|
||||||
|
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
|
||||||
|
fieldInfo, mergeState, mergedQuantizationState);
|
||||||
|
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||||
|
DocsWithFieldSet docsWithField =
|
||||||
|
writeQuantizedVectorData(quantizedVectorData, byteVectorValues);
|
||||||
|
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
float quantile =
|
||||||
|
this.quantile == null
|
||||||
|
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||||
|
: this.quantile;
|
||||||
|
writeMeta(
|
||||||
|
fieldInfo,
|
||||||
|
segmentWriteState.segmentInfo.maxDoc(),
|
||||||
|
vectorDataOffset,
|
||||||
|
vectorDataLength,
|
||||||
|
quantile,
|
||||||
|
mergedQuantizationState.getLowerQuantile(),
|
||||||
|
mergedQuantizationState.getUpperQuantile(),
|
||||||
|
docsWithField);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public CloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
|
||||||
|
FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||||
|
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||||
|
// Simply merge the underlying delegate, which just copies the raw vector data to a new
|
||||||
|
// segment file
|
||||||
|
rawVectorDelegate.mergeOneField(fieldInfo, mergeState);
|
||||||
|
ScalarQuantizer mergedQuantizationState = mergeQuantiles(fieldInfo, mergeState);
|
||||||
|
return mergeOneFieldToIndex(
|
||||||
|
segmentWriteState, fieldInfo, mergeState, mergedQuantizationState);
|
||||||
|
}
|
||||||
|
// We only merge the delegate, since the field type isn't float32, quantization wasn't
|
||||||
|
// supported, so bypass it.
|
||||||
|
return rawVectorDelegate.mergeOneFieldToIndex(fieldInfo, mergeState);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||||
|
rawVectorDelegate.flush(maxDoc, sortMap);
|
||||||
|
for (FieldWriter field : fields) {
|
||||||
|
field.finish();
|
||||||
|
if (sortMap == null) {
|
||||||
|
writeField(field, maxDoc);
|
||||||
|
} else {
|
||||||
|
writeSortingField(field, maxDoc, sortMap);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void finish() throws IOException {
|
||||||
if (finished) {
|
if (finished) {
|
||||||
throw new IllegalStateException("already finished");
|
throw new IllegalStateException("already finished");
|
||||||
}
|
}
|
||||||
finished = true;
|
finished = true;
|
||||||
|
rawVectorDelegate.finish();
|
||||||
|
if (meta != null) {
|
||||||
|
// write end of fields marker
|
||||||
|
meta.writeInt(-1);
|
||||||
|
CodecUtil.writeFooter(meta);
|
||||||
|
}
|
||||||
if (quantizedVectorData != null) {
|
if (quantizedVectorData != null) {
|
||||||
CodecUtil.writeFooter(quantizedVectorData);
|
CodecUtil.writeFooter(quantizedVectorData);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private long[] writeField(QuantizationFieldVectorWriter fieldData) throws IOException {
|
@Override
|
||||||
long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
public long ramBytesUsed() {
|
||||||
writeQuantizedVectors(fieldData);
|
long total = SHALLOW_RAM_BYTES_USED;
|
||||||
long quantizedVectorDataLength =
|
for (FieldWriter field : fields) {
|
||||||
quantizedVectorData.getFilePointer() - quantizedVectorDataOffset;
|
total += field.ramBytesUsed();
|
||||||
return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength};
|
}
|
||||||
|
return total;
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeQuantizedVectors(QuantizationFieldVectorWriter fieldData) throws IOException {
|
private void writeField(FieldWriter fieldData, int maxDoc) throws IOException {
|
||||||
|
// write vector values
|
||||||
|
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||||
|
writeQuantizedVectors(fieldData);
|
||||||
|
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
|
||||||
|
writeMeta(
|
||||||
|
fieldData.fieldInfo,
|
||||||
|
maxDoc,
|
||||||
|
vectorDataOffset,
|
||||||
|
vectorDataLength,
|
||||||
|
quantile,
|
||||||
|
fieldData.minQuantile,
|
||||||
|
fieldData.maxQuantile,
|
||||||
|
fieldData.docsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeMeta(
|
||||||
|
FieldInfo field,
|
||||||
|
int maxDoc,
|
||||||
|
long vectorDataOffset,
|
||||||
|
long vectorDataLength,
|
||||||
|
Float configuredQuantizationQuantile,
|
||||||
|
Float lowerQuantile,
|
||||||
|
Float upperQuantile,
|
||||||
|
DocsWithFieldSet docsWithField)
|
||||||
|
throws IOException {
|
||||||
|
meta.writeInt(field.number);
|
||||||
|
meta.writeInt(field.getVectorEncoding().ordinal());
|
||||||
|
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||||
|
meta.writeVLong(vectorDataOffset);
|
||||||
|
meta.writeVLong(vectorDataLength);
|
||||||
|
meta.writeVInt(field.getVectorDimension());
|
||||||
|
int count = docsWithField.cardinality();
|
||||||
|
meta.writeInt(count);
|
||||||
|
if (count > 0) {
|
||||||
|
assert Float.isFinite(lowerQuantile) && Float.isFinite(upperQuantile);
|
||||||
|
meta.writeInt(
|
||||||
|
Float.floatToIntBits(
|
||||||
|
configuredQuantizationQuantile != null
|
||||||
|
? configuredQuantizationQuantile
|
||||||
|
: calculateDefaultQuantile(field.getVectorDimension())));
|
||||||
|
meta.writeInt(Float.floatToIntBits(lowerQuantile));
|
||||||
|
meta.writeInt(Float.floatToIntBits(upperQuantile));
|
||||||
|
}
|
||||||
|
// write docIDs
|
||||||
|
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
||||||
|
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
|
||||||
|
}
|
||||||
|
|
||||||
|
private void writeQuantizedVectors(FieldWriter fieldData) throws IOException {
|
||||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||||
byte[] vector = new byte[fieldData.dim];
|
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
|
||||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
|
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
|
||||||
for (float[] v : fieldData.floatVectors) {
|
for (float[] v : fieldData.floatVectors) {
|
||||||
if (fieldData.normalize) {
|
if (fieldData.normalize) {
|
||||||
System.arraycopy(v, 0, copy, 0, copy.length);
|
System.arraycopy(v, 0, copy, 0, copy.length);
|
||||||
|
@ -151,7 +303,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
}
|
}
|
||||||
|
|
||||||
float offsetCorrection =
|
float offsetCorrection =
|
||||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
|
||||||
quantizedVectorData.writeBytes(vector, vector.length);
|
quantizedVectorData.writeBytes(vector, vector.length);
|
||||||
offsetBuffer.putFloat(offsetCorrection);
|
offsetBuffer.putFloat(offsetCorrection);
|
||||||
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
||||||
|
@ -159,14 +311,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private long[] writeSortingField(
|
private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocMap sortMap)
|
||||||
QuantizationFieldVectorWriter fieldData,
|
|
||||||
Sorter.DocMap sortMap,
|
|
||||||
DocsWithFieldSet docsWithField)
|
|
||||||
throws IOException {
|
throws IOException {
|
||||||
final int[] docIdOffsets = new int[sortMap.size()];
|
final int[] docIdOffsets = new int[sortMap.size()];
|
||||||
int offset = 1; // 0 means no vector for this (field, document)
|
int offset = 1; // 0 means no vector for this (field, document)
|
||||||
DocIdSetIterator iterator = docsWithField.iterator();
|
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||||
for (int docID = iterator.nextDoc();
|
for (int docID = iterator.nextDoc();
|
||||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||||
docID = iterator.nextDoc()) {
|
docID = iterator.nextDoc()) {
|
||||||
|
@ -175,13 +324,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
}
|
}
|
||||||
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
||||||
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
||||||
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
|
|
||||||
int ord = 0;
|
int ord = 0;
|
||||||
int doc = 0;
|
int doc = 0;
|
||||||
for (int docIdOffset : docIdOffsets) {
|
for (int docIdOffset : docIdOffsets) {
|
||||||
if (docIdOffset != 0) {
|
if (docIdOffset != 0) {
|
||||||
ordMap[ord] = docIdOffset - 1;
|
ordMap[ord] = docIdOffset - 1;
|
||||||
oldOrdMap[docIdOffset - 1] = ord;
|
|
||||||
newDocsWithField.add(doc);
|
newDocsWithField.add(doc);
|
||||||
ord++;
|
ord++;
|
||||||
}
|
}
|
||||||
|
@ -192,16 +339,22 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||||
writeSortedQuantizedVectors(fieldData, ordMap);
|
writeSortedQuantizedVectors(fieldData, ordMap);
|
||||||
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
||||||
|
writeMeta(
|
||||||
return new long[] {vectorDataOffset, quantizedVectorLength};
|
fieldData.fieldInfo,
|
||||||
|
maxDoc,
|
||||||
|
vectorDataOffset,
|
||||||
|
quantizedVectorLength,
|
||||||
|
quantile,
|
||||||
|
fieldData.minQuantile,
|
||||||
|
fieldData.maxQuantile,
|
||||||
|
newDocsWithField);
|
||||||
}
|
}
|
||||||
|
|
||||||
void writeSortedQuantizedVectors(QuantizationFieldVectorWriter fieldData, int[] ordMap)
|
private void writeSortedQuantizedVectors(FieldWriter fieldData, int[] ordMap) throws IOException {
|
||||||
throws IOException {
|
|
||||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||||
byte[] vector = new byte[fieldData.dim];
|
byte[] vector = new byte[fieldData.fieldInfo.getVectorDimension()];
|
||||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
|
float[] copy = fieldData.normalize ? new float[fieldData.fieldInfo.getVectorDimension()] : null;
|
||||||
for (int ordinal : ordMap) {
|
for (int ordinal : ordMap) {
|
||||||
float[] v = fieldData.floatVectors.get(ordinal);
|
float[] v = fieldData.floatVectors.get(ordinal);
|
||||||
if (fieldData.normalize) {
|
if (fieldData.normalize) {
|
||||||
|
@ -209,9 +362,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
VectorUtil.l2normalize(copy);
|
VectorUtil.l2normalize(copy);
|
||||||
v = copy;
|
v = copy;
|
||||||
}
|
}
|
||||||
|
|
||||||
float offsetCorrection =
|
float offsetCorrection =
|
||||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
scalarQuantizer.quantize(v, vector, fieldData.fieldInfo.getVectorSimilarityFunction());
|
||||||
quantizedVectorData.writeBytes(vector, vector.length);
|
quantizedVectorData.writeBytes(vector, vector.length);
|
||||||
offsetBuffer.putFloat(offsetCorrection);
|
offsetBuffer.putFloat(offsetCorrection);
|
||||||
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
||||||
|
@ -219,10 +371,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
private ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState)
|
||||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
throws IOException {
|
||||||
return null;
|
assert fieldInfo.getVectorEncoding() == VectorEncoding.FLOAT32;
|
||||||
}
|
|
||||||
float quantile =
|
float quantile =
|
||||||
this.quantile == null
|
this.quantile == null
|
||||||
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||||
|
@ -230,15 +381,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
|
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
|
||||||
}
|
}
|
||||||
|
|
||||||
ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField(
|
private ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneFieldToIndex(
|
||||||
SegmentWriteState segmentWriteState,
|
SegmentWriteState segmentWriteState,
|
||||||
FieldInfo fieldInfo,
|
FieldInfo fieldInfo,
|
||||||
MergeState mergeState,
|
MergeState mergeState,
|
||||||
ScalarQuantizer mergedQuantizationState)
|
ScalarQuantizer mergedQuantizationState)
|
||||||
throws IOException {
|
throws IOException {
|
||||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||||
return null;
|
|
||||||
}
|
|
||||||
IndexOutput tempQuantizedVectorData =
|
IndexOutput tempQuantizedVectorData =
|
||||||
segmentWriteState.directory.createTempOutput(
|
segmentWriteState.directory.createTempOutput(
|
||||||
quantizedVectorData.getName(), "temp", segmentWriteState.context);
|
quantizedVectorData.getName(), "temp", segmentWriteState.context);
|
||||||
|
@ -257,7 +406,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
tempQuantizedVectorData.getName(), segmentWriteState.context);
|
tempQuantizedVectorData.getName(), segmentWriteState.context);
|
||||||
quantizedVectorData.copyBytes(
|
quantizedVectorData.copyBytes(
|
||||||
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
|
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
|
||||||
|
long vectorDataLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
||||||
CodecUtil.retrieveChecksum(quantizationDataInput);
|
CodecUtil.retrieveChecksum(quantizationDataInput);
|
||||||
|
float quantile =
|
||||||
|
this.quantile == null
|
||||||
|
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||||
|
: this.quantile;
|
||||||
|
writeMeta(
|
||||||
|
fieldInfo,
|
||||||
|
segmentWriteState.segmentInfo.maxDoc(),
|
||||||
|
vectorDataOffset,
|
||||||
|
vectorDataLength,
|
||||||
|
quantile,
|
||||||
|
mergedQuantizationState.getLowerQuantile(),
|
||||||
|
mergedQuantizationState.getUpperQuantile(),
|
||||||
|
docsWithField);
|
||||||
success = true;
|
success = true;
|
||||||
final IndexInput finalQuantizationDataInput = quantizationDataInput;
|
final IndexInput finalQuantizationDataInput = quantizationDataInput;
|
||||||
return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
||||||
|
@ -265,6 +428,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
IOUtils.close(finalQuantizationDataInput);
|
IOUtils.close(finalQuantizationDataInput);
|
||||||
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
|
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
|
||||||
},
|
},
|
||||||
|
docsWithField.cardinality(),
|
||||||
new ScalarQuantizedRandomVectorScorerSupplier(
|
new ScalarQuantizedRandomVectorScorerSupplier(
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
fieldInfo.getVectorSimilarityFunction(),
|
||||||
mergedQuantizationState,
|
mergedQuantizationState,
|
||||||
|
@ -427,43 +591,35 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public void close() throws IOException {
|
||||||
return BASE_RAM_BYTES_USED;
|
IOUtils.close(meta, quantizedVectorData, rawVectorDelegate);
|
||||||
}
|
}
|
||||||
|
|
||||||
static class QuantizationFieldVectorWriter implements Accountable {
|
static class FieldWriter extends FlatFieldVectorsWriter<float[]> {
|
||||||
private static final long SHALLOW_SIZE =
|
private static final long SHALLOW_SIZE = shallowSizeOfInstance(FieldWriter.class);
|
||||||
shallowSizeOfInstance(QuantizationFieldVectorWriter.class);
|
|
||||||
private final int dim;
|
|
||||||
private final List<float[]> floatVectors;
|
private final List<float[]> floatVectors;
|
||||||
private final boolean normalize;
|
private final FieldInfo fieldInfo;
|
||||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
|
||||||
private final float quantile;
|
private final float quantile;
|
||||||
private final InfoStream infoStream;
|
private final InfoStream infoStream;
|
||||||
|
private final boolean normalize;
|
||||||
private float minQuantile = Float.POSITIVE_INFINITY;
|
private float minQuantile = Float.POSITIVE_INFINITY;
|
||||||
private float maxQuantile = Float.NEGATIVE_INFINITY;
|
private float maxQuantile = Float.NEGATIVE_INFINITY;
|
||||||
private boolean finished;
|
private boolean finished;
|
||||||
|
private final DocsWithFieldSet docsWithField;
|
||||||
|
|
||||||
static QuantizationFieldVectorWriter create(
|
@SuppressWarnings("unchecked")
|
||||||
FieldInfo fieldInfo, float quantile, InfoStream infoStream) {
|
FieldWriter(
|
||||||
return new QuantizationFieldVectorWriter(
|
|
||||||
fieldInfo.getVectorDimension(),
|
|
||||||
quantile,
|
|
||||||
fieldInfo.getVectorSimilarityFunction(),
|
|
||||||
infoStream);
|
|
||||||
}
|
|
||||||
|
|
||||||
QuantizationFieldVectorWriter(
|
|
||||||
int dim,
|
|
||||||
float quantile,
|
float quantile,
|
||||||
VectorSimilarityFunction vectorSimilarityFunction,
|
FieldInfo fieldInfo,
|
||||||
InfoStream infoStream) {
|
InfoStream infoStream,
|
||||||
this.dim = dim;
|
KnnFieldVectorsWriter<?> indexWriter) {
|
||||||
|
super((KnnFieldVectorsWriter<float[]>) indexWriter);
|
||||||
this.quantile = quantile;
|
this.quantile = quantile;
|
||||||
this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE;
|
this.fieldInfo = fieldInfo;
|
||||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
this.normalize = fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE;
|
||||||
this.floatVectors = new ArrayList<>();
|
this.floatVectors = new ArrayList<>();
|
||||||
this.infoStream = infoStream;
|
this.infoStream = infoStream;
|
||||||
|
this.docsWithField = new DocsWithFieldSet();
|
||||||
}
|
}
|
||||||
|
|
||||||
void finish() throws IOException {
|
void finish() throws IOException {
|
||||||
|
@ -475,15 +631,17 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ScalarQuantizer quantizer =
|
ScalarQuantizer quantizer =
|
||||||
ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile);
|
ScalarQuantizer.fromVectors(
|
||||||
|
new FloatVectorWrapper(
|
||||||
|
floatVectors,
|
||||||
|
fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE),
|
||||||
|
quantile);
|
||||||
minQuantile = quantizer.getLowerQuantile();
|
minQuantile = quantizer.getLowerQuantile();
|
||||||
maxQuantile = quantizer.getUpperQuantile();
|
maxQuantile = quantizer.getUpperQuantile();
|
||||||
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
||||||
infoStream.message(
|
infoStream.message(
|
||||||
QUANTIZED_VECTOR_COMPONENT,
|
QUANTIZED_VECTOR_COMPONENT,
|
||||||
"quantized field="
|
"quantized field="
|
||||||
+ " dimension="
|
|
||||||
+ dim
|
|
||||||
+ " quantile="
|
+ " quantile="
|
||||||
+ quantile
|
+ quantile
|
||||||
+ " minQuantile="
|
+ " minQuantile="
|
||||||
|
@ -494,24 +652,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
finished = true;
|
finished = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void addValue(float[] vectorValue) throws IOException {
|
|
||||||
floatVectors.add(vectorValue);
|
|
||||||
}
|
|
||||||
|
|
||||||
float getMinQuantile() {
|
|
||||||
assert finished;
|
|
||||||
return minQuantile;
|
|
||||||
}
|
|
||||||
|
|
||||||
float getMaxQuantile() {
|
|
||||||
assert finished;
|
|
||||||
return maxQuantile;
|
|
||||||
}
|
|
||||||
|
|
||||||
float getQuantile() {
|
|
||||||
return quantile;
|
|
||||||
}
|
|
||||||
|
|
||||||
ScalarQuantizer createQuantizer() {
|
ScalarQuantizer createQuantizer() {
|
||||||
assert finished;
|
assert finished;
|
||||||
return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
|
return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
|
||||||
|
@ -519,8 +659,26 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public long ramBytesUsed() {
|
public long ramBytesUsed() {
|
||||||
if (floatVectors.size() == 0) return SHALLOW_SIZE;
|
long size = SHALLOW_SIZE;
|
||||||
return SHALLOW_SIZE + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
if (indexingDelegate != null) {
|
||||||
|
size += indexingDelegate.ramBytesUsed();
|
||||||
|
}
|
||||||
|
if (floatVectors.size() == 0) return size;
|
||||||
|
return size + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void addValue(int docID, float[] vectorValue) throws IOException {
|
||||||
|
docsWithField.add(docID);
|
||||||
|
floatVectors.add(vectorValue);
|
||||||
|
if (indexingDelegate != null) {
|
||||||
|
indexingDelegate.addValue(docID, vectorValue);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public float[] copyValue(float[] vectorValue) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -613,6 +771,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
// Either our quantization parameters are way different than the merged ones
|
// Either our quantization parameters are way different than the merged ones
|
||||||
// Or we have never been quantized.
|
// Or we have never been quantized.
|
||||||
if (reader == null
|
if (reader == null
|
||||||
|
|| reader.getQuantizationState(fieldInfo.name) == null
|
||||||
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
|
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
|
||||||
sub =
|
sub =
|
||||||
new QuantizedByteVectorValueSub(
|
new QuantizedByteVectorValueSub(
|
||||||
|
@ -702,6 +861,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
private final FloatVectorValues values;
|
private final FloatVectorValues values;
|
||||||
private final ScalarQuantizer quantizer;
|
private final ScalarQuantizer quantizer;
|
||||||
private final byte[] quantizedVector;
|
private final byte[] quantizedVector;
|
||||||
|
private final float[] normalizedVector;
|
||||||
private float offsetValue = 0f;
|
private float offsetValue = 0f;
|
||||||
|
|
||||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||||
|
@ -714,6 +874,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
this.quantizer = quantizer;
|
this.quantizer = quantizer;
|
||||||
this.quantizedVector = new byte[values.dimension()];
|
this.quantizedVector = new byte[values.dimension()];
|
||||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||||
|
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||||
|
this.normalizedVector = new float[values.dimension()];
|
||||||
|
} else {
|
||||||
|
this.normalizedVector = null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -745,8 +910,7 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
public int nextDoc() throws IOException {
|
public int nextDoc() throws IOException {
|
||||||
int doc = values.nextDoc();
|
int doc = values.nextDoc();
|
||||||
if (doc != NO_MORE_DOCS) {
|
if (doc != NO_MORE_DOCS) {
|
||||||
offsetValue =
|
quantize();
|
||||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
|
||||||
}
|
}
|
||||||
return doc;
|
return doc;
|
||||||
}
|
}
|
||||||
|
@ -755,10 +919,21 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
public int advance(int target) throws IOException {
|
public int advance(int target) throws IOException {
|
||||||
int doc = values.advance(target);
|
int doc = values.advance(target);
|
||||||
if (doc != NO_MORE_DOCS) {
|
if (doc != NO_MORE_DOCS) {
|
||||||
|
quantize();
|
||||||
|
}
|
||||||
|
return doc;
|
||||||
|
}
|
||||||
|
|
||||||
|
private void quantize() throws IOException {
|
||||||
|
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||||
|
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||||
|
VectorUtil.l2normalize(normalizedVector);
|
||||||
|
offsetValue =
|
||||||
|
quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
|
||||||
|
} else {
|
||||||
offsetValue =
|
offsetValue =
|
||||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||||
}
|
}
|
||||||
return doc;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -767,11 +942,13 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
|
|
||||||
private final ScalarQuantizedRandomVectorScorerSupplier supplier;
|
private final ScalarQuantizedRandomVectorScorerSupplier supplier;
|
||||||
private final Closeable onClose;
|
private final Closeable onClose;
|
||||||
|
private final int numVectors;
|
||||||
|
|
||||||
ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
||||||
Closeable onClose, ScalarQuantizedRandomVectorScorerSupplier supplier) {
|
Closeable onClose, int numVectors, ScalarQuantizedRandomVectorScorerSupplier supplier) {
|
||||||
this.onClose = onClose;
|
this.onClose = onClose;
|
||||||
this.supplier = supplier;
|
this.supplier = supplier;
|
||||||
|
this.numVectors = numVectors;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -788,6 +965,11 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
public void close() throws IOException {
|
public void close() throws IOException {
|
||||||
onClose.close();
|
onClose.close();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int totalVectorCount() {
|
||||||
|
return numVectors;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static final class OffsetCorrectedQuantizedByteVectorValues
|
private static final class OffsetCorrectedQuantizedByteVectorValues
|
||||||
|
|
|
@ -98,8 +98,6 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
|
||||||
|
|
||||||
static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||||
|
|
||||||
private int doc = -1;
|
private int doc = -1;
|
||||||
|
@ -138,7 +136,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return acceptDocs;
|
return acceptDocs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -196,7 +194,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
if (acceptDocs == null) {
|
if (acceptDocs == null) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
@ -268,7 +266,7 @@ abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValue
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
Bits getAcceptOrds(Bits acceptDocs) {
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,7 +25,8 @@ import org.apache.lucene.util.VectorUtil;
|
||||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||||
|
|
||||||
/** Quantized vector scorer */
|
/** Quantized vector scorer */
|
||||||
final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
|
final class ScalarQuantizedRandomVectorScorer
|
||||||
|
extends RandomVectorScorer.AbstractRandomVectorScorer<byte[]> {
|
||||||
|
|
||||||
private static float quantizeQuery(
|
private static float quantizeQuery(
|
||||||
float[] query,
|
float[] query,
|
||||||
|
@ -54,6 +55,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
|
||||||
RandomAccessQuantizedByteVectorValues values,
|
RandomAccessQuantizedByteVectorValues values,
|
||||||
byte[] query,
|
byte[] query,
|
||||||
float queryOffset) {
|
float queryOffset) {
|
||||||
|
super(values);
|
||||||
this.quantizedQuery = query;
|
this.quantizedQuery = query;
|
||||||
this.queryOffset = queryOffset;
|
this.queryOffset = queryOffset;
|
||||||
this.similarity = similarityFunction;
|
this.similarity = similarityFunction;
|
||||||
|
@ -65,6 +67,7 @@ final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
|
||||||
ScalarQuantizer scalarQuantizer,
|
ScalarQuantizer scalarQuantizer,
|
||||||
RandomAccessQuantizedByteVectorValues values,
|
RandomAccessQuantizedByteVectorValues values,
|
||||||
float[] query) {
|
float[] query) {
|
||||||
|
super(values);
|
||||||
byte[] quantizedQuery = new byte[query.length];
|
byte[] quantizedQuery = new byte[query.length];
|
||||||
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
|
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
|
||||||
this.quantizedQuery = quantizedQuery;
|
this.quantizedQuery = quantizedQuery;
|
||||||
|
|
|
@ -26,5 +26,6 @@ import java.io.Closeable;
|
||||||
* <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily
|
* <p>NOTE: the {@link #copy()} returned {@link RandomVectorScorerSupplier} is not necessarily
|
||||||
* closeable
|
* closeable
|
||||||
*/
|
*/
|
||||||
public interface CloseableRandomVectorScorerSupplier
|
public interface CloseableRandomVectorScorerSupplier extends Closeable, RandomVectorScorerSupplier {
|
||||||
extends Closeable, RandomVectorScorerSupplier {}
|
int totalVectorCount();
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
package org.apache.lucene.util.hnsw;
|
package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
* Provides random access to vectors by dense ordinal. This interface is used by HNSW-based
|
||||||
|
@ -56,4 +57,14 @@ public interface RandomAccessVectorValues<T> {
|
||||||
default int ordToDoc(int ord) {
|
default int ordToDoc(int ord) {
|
||||||
return ord;
|
return ord;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
|
||||||
|
*
|
||||||
|
* @param acceptDocs the accept docs
|
||||||
|
* @return the accept docs
|
||||||
|
*/
|
||||||
|
default Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return acceptDocs;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
|
/** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */
|
||||||
public interface RandomVectorScorer {
|
public interface RandomVectorScorer {
|
||||||
|
@ -30,6 +31,31 @@ public interface RandomVectorScorer {
|
||||||
*/
|
*/
|
||||||
float score(int node) throws IOException;
|
float score(int node) throws IOException;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return the maximum possible ordinal for this scorer
|
||||||
|
*/
|
||||||
|
int maxOrd();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Translates vector ordinal to the correct document ID. By default, this is an identity function.
|
||||||
|
*
|
||||||
|
* @param ord the vector ordinal
|
||||||
|
* @return the document Id for that vector ordinal
|
||||||
|
*/
|
||||||
|
default int ordToDoc(int ord) {
|
||||||
|
return ord;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the {@link Bits} representing live documents. By default, this is an identity function.
|
||||||
|
*
|
||||||
|
* @param acceptDocs the accept docs
|
||||||
|
* @return the accept docs
|
||||||
|
*/
|
||||||
|
default Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return acceptDocs;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Creates a default scorer for float vectors.
|
* Creates a default scorer for float vectors.
|
||||||
*
|
*
|
||||||
|
@ -53,7 +79,12 @@ public interface RandomVectorScorer {
|
||||||
+ " differs from field dimension: "
|
+ " differs from field dimension: "
|
||||||
+ vectors.dimension());
|
+ vectors.dimension());
|
||||||
}
|
}
|
||||||
return node -> similarityFunction.compare(query, vectors.vectorValue(node));
|
return new AbstractRandomVectorScorer<>(vectors) {
|
||||||
|
@Override
|
||||||
|
public float score(int node) throws IOException {
|
||||||
|
return similarityFunction.compare(query, vectors.vectorValue(node));
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -79,6 +110,44 @@ public interface RandomVectorScorer {
|
||||||
+ " differs from field dimension: "
|
+ " differs from field dimension: "
|
||||||
+ vectors.dimension());
|
+ vectors.dimension());
|
||||||
}
|
}
|
||||||
return node -> similarityFunction.compare(query, vectors.vectorValue(node));
|
return new AbstractRandomVectorScorer<>(vectors) {
|
||||||
|
@Override
|
||||||
|
public float score(int node) throws IOException {
|
||||||
|
return similarityFunction.compare(query, vectors.vectorValue(node));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a default scorer for random access vectors.
|
||||||
|
*
|
||||||
|
* @param <T> the type of the vector values
|
||||||
|
*/
|
||||||
|
abstract class AbstractRandomVectorScorer<T> implements RandomVectorScorer {
|
||||||
|
private final RandomAccessVectorValues<T> values;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a new scorer for the given vector values.
|
||||||
|
*
|
||||||
|
* @param values the vector values
|
||||||
|
*/
|
||||||
|
public AbstractRandomVectorScorer(RandomAccessVectorValues<T> values) {
|
||||||
|
this.values = values;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int maxOrd() {
|
||||||
|
return values.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public int ordToDoc(int ord) {
|
||||||
|
return values.ordToDoc(ord);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
return values.getAcceptOrds(acceptDocs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -87,8 +87,12 @@ public interface RandomVectorScorerSupplier {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||||
return cand ->
|
return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
|
||||||
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
|
@Override
|
||||||
|
public float score(int cand) throws IOException {
|
||||||
|
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -115,8 +119,12 @@ public interface RandomVectorScorerSupplier {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||||
return cand ->
|
return new RandomVectorScorer.AbstractRandomVectorScorer<>(vectors) {
|
||||||
similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
|
@Override
|
||||||
|
public float score(int cand) throws IOException {
|
||||||
|
return similarityFunction.compare(vectors1.vectorValue(ord), vectors2.vectorValue(cand));
|
||||||
|
}
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
|
|
@ -14,3 +14,4 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
|
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
|
||||||
|
org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat
|
||||||
|
|
|
@ -47,10 +47,9 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
||||||
return new Lucene99Codec() {
|
return new Lucene99Codec() {
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
return new Lucene99HnswVectorsFormat(
|
return new Lucene99HnswScalarQuantizedVectorsFormat(
|
||||||
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
|
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
|
||||||
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
|
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||||
new Lucene99ScalarQuantizedVectorsFormat());
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -145,12 +144,11 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
||||||
new FilterCodec("foo", Codec.getDefault()) {
|
new FilterCodec("foo", Codec.getDefault()) {
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsFormat knnVectorsFormat() {
|
public KnnVectorsFormat knnVectorsFormat() {
|
||||||
return new Lucene99HnswVectorsFormat(
|
return new Lucene99HnswScalarQuantizedVectorsFormat(10, 20, 1, 0.9f, null);
|
||||||
10, 20, new Lucene99ScalarQuantizedVectorsFormat(0.9f));
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
String expectedString =
|
String expectedString =
|
||||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9))";
|
"Lucene99HnswScalarQuantizedVectorsFormat(name=Lucene99HnswScalarQuantizedVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9, rawVectorFormat=Lucene99FlatVectorsFormat()))";
|
||||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
String expectedString =
|
String expectedString =
|
||||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=none)";
|
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, flatVectorFormat=Lucene99FlatVectorsFormat())";
|
||||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
/*
|
|
||||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
|
||||||
* contributor license agreements. See the NOTICE file distributed with
|
|
||||||
* this work for additional information regarding copyright ownership.
|
|
||||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
|
||||||
* (the "License"); you may not use this file except in compliance with
|
|
||||||
* the License. You may obtain a copy of the License at
|
|
||||||
*
|
|
||||||
* http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
* See the License for the specific language governing permissions and
|
|
||||||
* limitations under the License.
|
|
||||||
*/
|
|
||||||
package org.apache.lucene.codecs.lucene99;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
|
||||||
import org.apache.lucene.util.ScalarQuantizer;
|
|
||||||
|
|
||||||
public class TestLucene99ScalarQuantizedVectorsFormat extends LuceneTestCase {
|
|
||||||
|
|
||||||
public void testDefaultQuantile() {
|
|
||||||
float defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(99);
|
|
||||||
assertEquals(0.99f, defaultQuantile, 1e-5);
|
|
||||||
defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(1);
|
|
||||||
assertEquals(0.9f, defaultQuantile, 1e-5);
|
|
||||||
defaultQuantile =
|
|
||||||
Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(Integer.MAX_VALUE - 2);
|
|
||||||
assertEquals(1.0f, defaultQuantile, 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testLimits() {
|
|
||||||
expectThrows(
|
|
||||||
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(0.89f));
|
|
||||||
expectThrows(
|
|
||||||
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(1.1f));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testQuantileMergeWithMissing() {
|
|
||||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
|
||||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
|
||||||
quantiles.add(null);
|
|
||||||
List<Integer> segmentSizes = List.of(1, 1, 1, 1);
|
|
||||||
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f));
|
|
||||||
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(List.of(), List.of(), 0.1f));
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testQuantileMerge() {
|
|
||||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
|
||||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
|
||||||
List<Integer> segmentSizes = List.of(1, 1, 1);
|
|
||||||
ScalarQuantizer merged =
|
|
||||||
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
|
|
||||||
assertEquals(0.2f, merged.getLowerQuantile(), 1e-5);
|
|
||||||
assertEquals(0.3f, merged.getUpperQuantile(), 1e-5);
|
|
||||||
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void testQuantileMergeWithDifferentSegmentSizes() {
|
|
||||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
|
||||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
|
||||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
|
||||||
List<Integer> segmentSizes = List.of(1, 2, 3);
|
|
||||||
ScalarQuantizer merged =
|
|
||||||
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
|
|
||||||
assertEquals(0.2333333f, merged.getLowerQuantile(), 1e-5);
|
|
||||||
assertEquals(0.3333333f, merged.getUpperQuantile(), 1e-5);
|
|
||||||
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -32,9 +32,9 @@ import java.util.concurrent.CountDownLatch;
|
||||||
import org.apache.lucene.codecs.Codec;
|
import org.apache.lucene.codecs.Codec;
|
||||||
import org.apache.lucene.codecs.FilterCodec;
|
import org.apache.lucene.codecs.FilterCodec;
|
||||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||||
|
import org.apache.lucene.codecs.lucene99.Lucene99HnswScalarQuantizedVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
|
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
|
||||||
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
|
|
||||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||||
import org.apache.lucene.document.Document;
|
import org.apache.lucene.document.Document;
|
||||||
import org.apache.lucene.document.Field;
|
import org.apache.lucene.document.Field;
|
||||||
|
@ -83,12 +83,7 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
|
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
|
||||||
similarityFunction = VectorSimilarityFunction.values()[similarity];
|
similarityFunction = VectorSimilarityFunction.values()[similarity];
|
||||||
vectorEncoding = randomVectorEncoding();
|
vectorEncoding = randomVectorEncoding();
|
||||||
|
boolean quantized = randomBoolean();
|
||||||
Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat =
|
|
||||||
vectorEncoding.equals(VectorEncoding.FLOAT32) && randomBoolean()
|
|
||||||
? new Lucene99ScalarQuantizedVectorsFormat(1f)
|
|
||||||
: null;
|
|
||||||
|
|
||||||
codec =
|
codec =
|
||||||
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
|
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
|
||||||
@Override
|
@Override
|
||||||
|
@ -96,14 +91,16 @@ public class TestKnnGraph extends LuceneTestCase {
|
||||||
return new PerFieldKnnVectorsFormat() {
|
return new PerFieldKnnVectorsFormat() {
|
||||||
@Override
|
@Override
|
||||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||||
return new Lucene99HnswVectorsFormat(
|
return quantized
|
||||||
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, scalarQuantizedVectorsFormat);
|
? new Lucene99HnswScalarQuantizedVectorsFormat(
|
||||||
|
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH)
|
||||||
|
: new Lucene99HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if (vectorEncoding == VectorEncoding.FLOAT32 && scalarQuantizedVectorsFormat == null) {
|
if (vectorEncoding == VectorEncoding.FLOAT32) {
|
||||||
float32Codec = codec;
|
float32Codec = codec;
|
||||||
} else {
|
} else {
|
||||||
float32Codec =
|
float32Codec =
|
||||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.util.hnsw;
|
||||||
|
|
||||||
import java.io.IOException;
|
import java.io.IOException;
|
||||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||||
|
import org.apache.lucene.util.Bits;
|
||||||
|
|
||||||
public class TestNeighborArray extends LuceneTestCase {
|
public class TestNeighborArray extends LuceneTestCase {
|
||||||
|
|
||||||
|
@ -190,7 +191,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
||||||
neighbors.addOutOfOrder(7, Float.NaN);
|
neighbors.addOutOfOrder(7, Float.NaN);
|
||||||
neighbors.addOutOfOrder(6, Float.NaN);
|
neighbors.addOutOfOrder(6, Float.NaN);
|
||||||
neighbors.addOutOfOrder(4, Float.NaN);
|
neighbors.addOutOfOrder(4, Float.NaN);
|
||||||
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 1);
|
int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 1);
|
||||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||||
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
|
assertNodesEqual(new int[] {1, 2, 3, 4, 5, 6, 7}, neighbors);
|
||||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
||||||
|
@ -206,7 +207,7 @@ public class TestNeighborArray extends LuceneTestCase {
|
||||||
neighbors.addOutOfOrder(17, Float.NaN);
|
neighbors.addOutOfOrder(17, Float.NaN);
|
||||||
neighbors.addOutOfOrder(16, Float.NaN);
|
neighbors.addOutOfOrder(16, Float.NaN);
|
||||||
neighbors.addOutOfOrder(14, Float.NaN);
|
neighbors.addOutOfOrder(14, Float.NaN);
|
||||||
int[] unchecked = neighbors.sort(nodeId -> 7 - nodeId + 11);
|
int[] unchecked = neighbors.sort((TestRandomVectorScorer) nodeId -> 7 - nodeId + 11);
|
||||||
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
assertArrayEquals(new int[] {0, 1, 2, 3, 4, 5, 6}, unchecked);
|
||||||
assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors);
|
assertNodesEqual(new int[] {11, 12, 13, 14, 15, 16, 17}, neighbors);
|
||||||
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
assertScoresEqual(new float[] {7, 6, 5, 4, 3, 2, 1}, neighbors);
|
||||||
|
@ -223,4 +224,21 @@ public class TestNeighborArray extends LuceneTestCase {
|
||||||
assertEquals(nodes[i], neighbors.node[i]);
|
assertEquals(nodes[i], neighbors.node[i]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
interface TestRandomVectorScorer extends RandomVectorScorer {
|
||||||
|
@Override
|
||||||
|
default int maxOrd() {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
default int ordToDoc(int ord) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
default Bits getAcceptOrds(Bits acceptDocs) {
|
||||||
|
throw new UnsupportedOperationException();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue