Add new int8 scalar quantization to HNSW codec (#12582)

Adds new int8 scalar quantization for the HNSW codec. This introduces a new Lucene 9.9 format that automatically quantizes floating-point vectors into bytes on flush and merge.
Benjamin Trent 2023-10-24 14:31:54 -04:00 committed by GitHub
parent e5b55761d0
commit f2bf5339e5
44 changed files with 4682 additions and 90 deletions
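
For context, int8 scalar quantization maps each floating-point dimension onto a single byte using lower/upper quantiles estimated from the indexed vectors, and the format keeps a per-vector float32 correction for scoring (see the .veq section of the new format's javadoc below). The following is a minimal standalone sketch of that idea, not the ScalarQuantizer API this commit actually adds; the class name and constants are illustrative:

// Illustrative sketch of int8 scalar quantization. Lucene's real ScalarQuantizer
// additionally computes corrective terms and searches for good quantiles.
final class Int8QuantizerSketch {
  private final float lower, upper; // quantiles estimated from sampled vectors
  private final float scale;

  Int8QuantizerSketch(float lowerQuantile, float upperQuantile) {
    this.lower = lowerQuantile;
    this.upper = upperQuantile;
    this.scale = 127f / (upperQuantile - lowerQuantile);
  }

  // Quantizes each dimension into [0, 127], clamping outliers to the quantile range.
  byte[] quantize(float[] vector) {
    byte[] out = new byte[vector.length];
    for (int i = 0; i < vector.length; i++) {
      float clamped = Math.min(Math.max(vector[i], lower), upper);
      out[i] = (byte) Math.round((clamped - lower) * scale);
    }
    return out;
  }

  // Approximately recovers the original value of a quantized component.
  float dequantize(byte b) {
    return b / scale + lower;
  }
}

Clamping to quantiles rather than to the raw min/max keeps a few outliers from destroying the resolution of the remaining range, which is why the format records the configured quantile and the calculated lower/upper quantiles per field.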

View File

@@ -47,7 +47,8 @@ module org.apache.lucene.backward_codecs {
       org.apache.lucene.backward_codecs.lucene90.Lucene90HnswVectorsFormat,
       org.apache.lucene.backward_codecs.lucene91.Lucene91HnswVectorsFormat,
       org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat,
-      org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat;
+      org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat,
+      org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat;
   provides org.apache.lucene.codecs.Codec with
       org.apache.lucene.backward_codecs.lucene80.Lucene80Codec,
       org.apache.lucene.backward_codecs.lucene84.Lucene84Codec,

View File

@@ -40,7 +40,6 @@ import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat;
 import org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat;
 import org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
@@ -145,7 +144,7 @@ public class Lucene95Codec extends Codec {
   }

   @Override
-  public final SegmentInfoFormat segmentInfoFormat() {
+  public SegmentInfoFormat segmentInfoFormat() {
     return segmentInfosFormat;
   }
@@ -165,7 +164,7 @@ public class Lucene95Codec extends Codec {
   }

   @Override
-  public final KnnVectorsFormat knnVectorsFormat() {
+  public KnnVectorsFormat knnVectorsFormat() {
     return knnVectorsFormat;
   }

View File

@@ -15,7 +15,7 @@
  * limitations under the License.
  */
-package org.apache.lucene.codecs.lucene95;
+package org.apache.lucene.backward_codecs.lucene95;

 import java.io.IOException;
 import org.apache.lucene.codecs.KnnVectorsFormat;
@@ -96,7 +96,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
  *
  * @lucene.experimental
  */
-public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
+public class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
   static final String META_CODEC_NAME = "Lucene95HnswVectorsFormatMeta";
   static final String VECTOR_DATA_CODEC_NAME = "Lucene95HnswVectorsFormatData";
@@ -105,8 +105,8 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
   static final String VECTOR_DATA_EXTENSION = "vec";
   static final String VECTOR_INDEX_EXTENSION = "vex";
-  public static final int VERSION_START = 0;
-  public static final int VERSION_CURRENT = VERSION_START;
+  static final int VERSION_START = 0;
+  static final int VERSION_CURRENT = 1;

   /**
    * A maximum configurable maximum max conn.
@@ -137,14 +137,14 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
    * Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
    * {@link Lucene95HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
    */
-  private final int maxConn;
+  final int maxConn;

   /**
    * The number of candidate neighbors to track while searching the graph for each newly inserted
    * node. Defaults to {@link Lucene95HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
    * HnswGraph} for details.
    */
-  private final int beamWidth;
+  final int beamWidth;

   /** Constructs a format using default graph construction parameters */
   public Lucene95HnswVectorsFormat() {
@@ -179,7 +179,7 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
   @Override
   public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
-    return new Lucene95HnswVectorsWriter(state, maxConn, beamWidth);
+    throw new UnsupportedOperationException("Old codecs may only be used for reading");
   }

   @Override

View File

@@ -15,7 +15,7 @@
  * limitations under the License.
  */
-package org.apache.lucene.codecs.lucene95;
+package org.apache.lucene.backward_codecs.lucene95;

 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
@@ -26,6 +26,9 @@ import java.util.Map;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.HnswGraphProvider;
 import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
+import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
 import org.apache.lucene.index.ByteVectorValues;
 import org.apache.lucene.index.CorruptIndexException;
 import org.apache.lucene.index.FieldInfo;

View File

@@ -180,8 +180,8 @@
  *       of files, recording dimensionally indexed fields, to enable fast numeric range filtering
  *       and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
  *       intersection (2D, 3D).
- *   <li>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}. The
- *       vector format stores numeric vectors in a format optimized for random access and
+ *   <li>{@link org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat Vector values}.
+ *       The vector format stores numeric vectors in a format optimized for random access and
  *       computation, supporting high-dimensional nearest-neighbor search.
  * </ul>
  *
@@ -310,7 +310,7 @@
  * <td>Holds indexed points</td>
  * </tr>
  * <tr>
- * <td>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
+ * <td>{@link org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
 * <td>.vec, .vem</td>
 * <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data, and
 * <code>.vem</code> the vector metadata</td>

View File

@@ -17,3 +17,4 @@ org.apache.lucene.backward_codecs.lucene90.Lucene90HnswVectorsFormat
 org.apache.lucene.backward_codecs.lucene91.Lucene91HnswVectorsFormat
 org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat
 org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat
+org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat

View File

@@ -15,9 +15,9 @@
  * limitations under the License.
  */
-package org.apache.lucene.codecs.lucene95;
+package org.apache.lucene.backward_codecs.lucene95;

-import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
+import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
 import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

 import java.io.IOException;
@@ -29,14 +29,33 @@ import java.util.List;
 import org.apache.lucene.codecs.CodecUtil;
 import org.apache.lucene.codecs.KnnFieldVectorsWriter;
 import org.apache.lucene.codecs.KnnVectorsWriter;
-import org.apache.lucene.index.*;
+import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
+import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
+import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
+import org.apache.lucene.index.ByteVectorValues;
+import org.apache.lucene.index.DocsWithFieldSet;
+import org.apache.lucene.index.FieldInfo;
+import org.apache.lucene.index.FloatVectorValues;
+import org.apache.lucene.index.IndexFileNames;
+import org.apache.lucene.index.MergeState;
+import org.apache.lucene.index.SegmentWriteState;
 import org.apache.lucene.index.Sorter;
+import org.apache.lucene.index.VectorEncoding;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.*;
-import org.apache.lucene.util.hnsw.*;
+import org.apache.lucene.util.ArrayUtil;
+import org.apache.lucene.util.IOUtils;
+import org.apache.lucene.util.InfoStream;
+import org.apache.lucene.util.RamUsageEstimator;
+import org.apache.lucene.util.hnsw.HnswGraph;
 import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
+import org.apache.lucene.util.hnsw.HnswGraphBuilder;
+import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
+import org.apache.lucene.util.hnsw.NeighborArray;
+import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
+import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
+import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
 import org.apache.lucene.util.packed.DirectMonotonicWriter;

 /**
/** /**

View File

@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene95;
import org.apache.lucene.backward_codecs.lucene90.Lucene90RWSegmentInfoFormat;
import org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.SegmentInfoFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
/** Implements the Lucene 9.5 index format for backwards compat testing */
public class Lucene95RWCodec extends Lucene95Codec {
private final KnnVectorsFormat defaultKnnVectorsFormat;
private final KnnVectorsFormat knnVectorsFormat =
new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return defaultKnnVectorsFormat;
}
};
private final SegmentInfoFormat segmentInfosFormat = new Lucene90RWSegmentInfoFormat();
/** Instantiates a new codec. */
public Lucene95RWCodec() {
defaultKnnVectorsFormat =
new Lucene95RWHnswVectorsFormat(
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
}
@Override
public final KnnVectorsFormat knnVectorsFormat() {
return knnVectorsFormat;
}
@Override
public SegmentInfoFormat segmentInfoFormat() {
return segmentInfosFormat;
}
}

View File

@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene95;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
public final class Lucene95RWHnswVectorsFormat extends Lucene95HnswVectorsFormat {
public Lucene95RWHnswVectorsFormat(int maxConn, int beamWidth) {
super(maxConn, beamWidth);
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene95HnswVectorsWriter(state, maxConn, beamWidth);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene95HnswVectorsReader(state);
}
@Override
public String toString() {
return "Lucene95RWHnswVectorsFormat(name=Lucene95RWHnswVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ")";
}
}

View File

@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.backward_codecs.lucene95;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene95RWCodec();
}
public void testToString() {
Lucene95RWCodec customCodec =
new Lucene95RWCodec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95RWHnswVectorsFormat(10, 20);
}
};
String expectedString =
"Lucene95RWHnswVectorsFormat(name=Lucene95RWHnswVectorsFormat, maxConn=10, beamWidth=20)";
assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString());
}
}

View File

@@ -15,8 +15,8 @@
  * limitations under the License.
  */

-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
 import org.apache.lucene.codecs.lucene99.Lucene99Codec;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;

 /** Lucene Core. */
 @SuppressWarnings("module") // the test framework is compiled after the core...
@@ -70,7 +70,7 @@ module org.apache.lucene.core {
   provides org.apache.lucene.codecs.DocValuesFormat with
       org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
   provides org.apache.lucene.codecs.KnnVectorsFormat with
-      Lucene95HnswVectorsFormat;
+      Lucene99HnswVectorsFormat;
   provides org.apache.lucene.codecs.PostingsFormat with
       org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
   provides org.apache.lucene.index.SortFieldProvider with

View File

@@ -139,7 +139,7 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
   }

   /** View over multiple vector values supporting iterator-style access via DocIdMerger. */
-  protected static final class MergedVectorValues {
+  public static final class MergedVectorValues {
     private MergedVectorValues() {}

     /** Returns a merged view over all the segment's {@link FloatVectorValues}. */

View File

@@ -81,12 +81,12 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
       long vectorDataLength,
       IndexInput vectorData)
       throws IOException {
-    if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.BYTE) {
+    if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) {
       return new EmptyOffHeapVectorValues(dimension);
     }
     IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
     int byteSize = dimension;
-    if (configuration.docsWithFieldOffset == -1) {
+    if (configuration.isDense()) {
       return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
     } else {
       return new SparseOffHeapVectorValues(
@@ -94,9 +94,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
     }
   }

-  abstract Bits getAcceptOrds(Bits acceptDocs);
+  public abstract Bits getAcceptOrds(Bits acceptDocs);

-  static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
+  /**
+   * Dense vector values that are stored off-heap. This is the most common case when every doc has a
+   * vector.
+   */
+  public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
     private int doc = -1;
@@ -134,7 +138,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       return acceptDocs;
     }
   }
@@ -203,7 +207,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       if (acceptDocs == null) {
         return null;
       }
@@ -275,7 +279,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       return null;
     }
   }

View File

@@ -88,9 +88,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
     }
   }

-  abstract Bits getAcceptOrds(Bits acceptDocs);
+  public abstract Bits getAcceptOrds(Bits acceptDocs);

-  static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
+  /**
+   * Dense vector values that are stored off-heap. This is the most common case when every doc has a
+   * vector.
+   */
+  public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
     private int doc = -1;
@@ -128,7 +132,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       return acceptDocs;
     }
   }
@@ -197,7 +201,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       if (acceptDocs == null) {
         return null;
       }
@@ -269,7 +273,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
     }

     @Override
-    Bits getAcceptOrds(Bits acceptDocs) {
+    public Bits getAcceptOrds(Bits acceptDocs) {
       return null;
     }
   }

View File

@@ -23,6 +23,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
 import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
+import org.apache.lucene.store.RandomAccessInput;
 import org.apache.lucene.util.Accountable;
 import org.apache.lucene.util.RamUsageEstimator;
 import org.apache.lucene.util.packed.DirectMonotonicReader;
@@ -191,4 +192,48 @@ public class OrdToDocDISIReaderConfiguration implements Accountable {
   public long ramBytesUsed() {
     return SHALLOW_SIZE + RamUsageEstimator.sizeOf(meta);
   }
+
+  /**
+   * @param dataIn the dataIn
+   * @return the IndexedDISI for sparse values
+   * @throws IOException thrown when reading data fails
+   */
+  public IndexedDISI getIndexedDISI(IndexInput dataIn) throws IOException {
+    assert docsWithFieldOffset > -1;
+    return new IndexedDISI(
+        dataIn,
+        docsWithFieldOffset,
+        docsWithFieldLength,
+        jumpTableEntryCount,
+        denseRankPower,
+        size);
+  }
+
+  /**
+   * @param dataIn the dataIn
+   * @return the DirectMonotonicReader for sparse values
+   * @throws IOException thrown when reading data fails
+   */
+  public DirectMonotonicReader getDirectMonotonicReader(IndexInput dataIn) throws IOException {
+    assert docsWithFieldOffset > -1;
+    final RandomAccessInput addressesData =
+        dataIn.randomAccessSlice(addressesOffset, addressesLength);
+    return DirectMonotonicReader.getInstance(meta, addressesData);
+  }
+
+  /**
+   * @return true if the field has no vector values; false if the field is either dense or sparse
+   */
+  public boolean isEmpty() {
+    return docsWithFieldOffset == -2;
+  }
+
+  /**
+   * @return true if the field is dense (every document has a value); false if the field is sparse
+   *     (some documents are missing values)
+   */
+  public boolean isDense() {
+    return docsWithFieldOffset == -1;
+  }
 }
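
The new isEmpty()/isDense() accessors replace raw sentinel checks on docsWithFieldOffset at call sites, as in the OffHeapByteVectorValues hunk earlier. A hedged sketch of the three cases a caller distinguishes (the helper below is hypothetical, not part of this commit):

import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;

final class VectorFieldShape {
  // Hypothetical helper showing how readers branch on the stored configuration.
  static String describe(OrdToDocDISIReaderConfiguration config) {
    if (config.isEmpty()) {
      return "empty"; // docsWithFieldOffset == -2: no document has a vector
    } else if (config.isDense()) {
      return "dense"; // docsWithFieldOffset == -1: every document has a vector
    } else {
      return "sparse"; // ord-to-doc goes through IndexedDISI + DirectMonotonicReader
    }
  }
}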

View File

@@ -20,7 +20,6 @@ import java.util.Objects;
 import org.apache.lucene.codecs.*;
 import org.apache.lucene.codecs.lucene90.*;
 import org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
 import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
 import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
@@ -101,7 +100,7 @@ public class Lucene99Codec extends Codec {
         new Lucene90StoredFieldsFormat(Objects.requireNonNull(mode).storedMode);
     this.defaultPostingsFormat = new Lucene90PostingsFormat();
     this.defaultDVFormat = new Lucene90DocValuesFormat();
-    this.defaultKnnVectorsFormat = new Lucene95HnswVectorsFormat();
+    this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat();
   }

   @Override
@Override @Override

View File

@@ -0,0 +1,231 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.hnsw.HnswGraph;
/**
* Lucene 9.9 vector format, which encodes numeric vector values and an optional associated graph
* connecting the documents having values. The graph is used to power HNSW search. The format
* consists of three files, with an optional fourth file:
*
* <h2>.vec (vector data) file</h2>
*
* <p>For each field:
*
* <ul>
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
* sample is stored as an IEEE float in little-endian byte order.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* (sparse case only)
* <li>OrdToDoc encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter} (sparse
* case only)
* </ul>
*
* <h2>.vex (vector index)</h2>
*
* <p>Stores the graphs connecting the documents for each field, organized as a list of each
* node's neighbours, as follows:
*
* <ul>
* <li>For each level:
* <ul>
* <li>For each node:
* <ul>
* <li><b>[vint]</b> the number of neighbor nodes
* <li><b>array[vint]</b> the delta encoded neighbor ordinals
* </ul>
* </ul>
* <li>After all levels are encoded, memory offsets for each node's neighbor list, encoded by
* {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, are appended to the end of the
* file.
* </ul>
*
* <h2>.vem (vector metadata) file</h2>
*
* <p>For each field:
*
* <ul>
* <li><b>[int32]</b> field number
* <li><b>[int32]</b> vector similarity function ordinal
* <li><b>[byte]</b> 1 if the field stores quantized vectors, 0 otherwise
* <li><b>[int32]</b> if quantized: the configured quantile, as float32 bits.
* <li><b>[int32]</b> if quantized: the calculated lower quantile, as float32 bits.
* <li><b>[int32]</b> if quantized: the calculated upper quantile, as float32 bits.
* <li><b>[vlong]</b> if quantized: offset to this field's vectors in the .veq file
* <li><b>[vlong]</b> if quantized: length of this field's vectors, in bytes in the .veq file
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
* <li><b>[vlong]</b> length of this field's vectors, in bytes
* <li><b>[vlong]</b> offset to this field's index in the .vex file
* <li><b>[vlong]</b> length of this field's index data, in bytes
* <li><b>[vint]</b> dimension of this field's vectors
* <li><b>[int]</b> the number of documents having values for this field
* <li><b>[int8]</b> -1 if the field is dense (all documents have values); 0 if the field is
* sparse (some documents are missing values)
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* <li>OrdToDoc encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter} (sparse
* case only)
* <li><b>[vint]</b> the maximum number of connections (neighbours) that each node can have
* <li><b>[vint]</b> number of levels in the graph
* <li>Graph nodes by level. For each level
* <ul>
* <li><b>[vint]</b> the number of nodes on this level
* <li><b>array[vint]</b> for levels greater than 0, the list of nodes on this level, stored
* as delta-encoded node ordinals
* </ul>
* </ul>
*
* <h2>.veq (quantized vector data) file</h2>
*
* <p>For each field:
*
* <ul>
* <li>Vector data ordered by field, document ordinal, and vector dimension. Each vector dimension
* is stored as a single byte and every vector has a single float32 value for scoring
* corrections.
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
* (sparse case only)
* <li>OrdToDoc encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter} (sparse
* case only)
* </ul>
*
* @lucene.experimental
*/
public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
static final String VECTOR_DATA_CODEC_NAME = "Lucene99HnswVectorsFormatData";
static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
static final String META_EXTENSION = "vem";
static final String VECTOR_DATA_EXTENSION = "vec";
static final String VECTOR_INDEX_EXTENSION = "vex";
public static final int VERSION_START = 0;
public static final int VERSION_CURRENT = VERSION_START;
/**
* A maximum configurable maximum max conn.
*
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
* numbers here will use an inordinate amount of heap
*/
private static final int MAXIMUM_MAX_CONN = 512;
/** Default number of maximum connections per node */
public static final int DEFAULT_MAX_CONN = 16;
/**
* The maximum size of the queue to maintain while searching during graph construction. This
* maximum value preserves the ratio of DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN, i.e. `6.25 * 512 =
* 3200`.
*/
private static final int MAXIMUM_BEAM_WIDTH = 3200;
/** Default size of the queue maintained while searching during graph construction. */
public static final int DEFAULT_BEAM_WIDTH = 100;
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
/**
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
*/
private final int maxConn;
/**
* The number of candidate neighbors to track while searching the graph for each newly inserted
* node. Defaults to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
* HnswGraph} for details.
*/
private final int beamWidth;
/** Should this codec scalar quantize float32 vectors and use this format */
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;
/** Constructs a format using default graph construction parameters */
public Lucene99HnswVectorsFormat() {
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param maxConn the maximum number of connections to a node in the HNSW graph
* @param beamWidth the size of the queue maintained during graph construction.
* @param scalarQuantize the scalar quantization format
*/
public Lucene99HnswVectorsFormat(
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
super("Lucene99HnswVectorsFormat");
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
throw new IllegalArgumentException(
"maxConn must be positive and less than or equal to"
+ MAXIMUM_MAX_CONN
+ "; maxConn="
+ maxConn);
}
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
throw new IllegalArgumentException(
"beamWidth must be positive and less than or equal to"
+ MAXIMUM_BEAM_WIDTH
+ "; beamWidth="
+ beamWidth);
}
this.maxConn = maxConn;
this.beamWidth = beamWidth;
this.scalarQuantizedVectorsFormat = scalarQuantize;
}
@Override
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
}
@Override
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
return new Lucene99HnswVectorsReader(state);
}
@Override
public int getMaxDimensions(String fieldName) {
return 1024;
}
@Override
public String toString() {
return "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn="
+ maxConn
+ ", beamWidth="
+ beamWidth
+ ", quantizer="
+ (scalarQuantizedVectorsFormat == null ? "none" : scalarQuantizedVectorsFormat.toString())
+ ")";
}
}
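
Putting the new format to use: applications opt in through the codec's per-field hook. The sketch below is illustrative; the no-arg Lucene99ScalarQuantizedVectorsFormat constructor (implying default quantile selection) is an assumption not shown in this diff, and the wrapper class is hypothetical:

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.index.IndexWriterConfig;

class QuantizedIndexExample {
  // Opt every vector field into the quantized HNSW format. A non-null
  // scalar-quantization format enables auto-quantization on flush and merge;
  // passing null keeps plain float32 HNSW behavior.
  static IndexWriterConfig newConfig() {
    return new IndexWriterConfig()
        .setCodec(
            new Lucene99Codec() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                return new Lucene99HnswVectorsFormat(
                    16, 100, new Lucene99ScalarQuantizedVectorsFormat()); // assumed no-arg ctor
              }
            });
  }
}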

View File

@@ -0,0 +1,634 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.HnswGraphProvider;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.CorruptIndexException;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FieldInfos;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.RandomAccessInput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
*
* @lucene.experimental
*/
public final class Lucene99HnswVectorsReader extends KnnVectorsReader
implements QuantizedVectorsReader, HnswGraphProvider {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final IndexInput quantizedVectorData;
private final Lucene99ScalarQuantizedVectorsReader quantizedVectorsReader;
Lucene99HnswVectorsReader(SegmentReadState state) throws IOException {
this.fieldInfos = state.fieldInfos;
int versionMeta = readMetadata(state);
boolean success = false;
try {
vectorData =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME);
vectorIndex =
openDataInput(
state,
versionMeta,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
if (fields.values().stream().anyMatch(FieldEntry::hasQuantizedVectors)) {
quantizedVectorData =
openDataInput(
state,
versionMeta,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME);
quantizedVectorsReader = new Lucene99ScalarQuantizedVectorsReader(quantizedVectorData);
} else {
quantizedVectorData = null;
quantizedVectorsReader = null;
}
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
private int readMetadata(SegmentReadState state) throws IOException {
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
int versionMeta = -1;
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
Throwable priorE = null;
try {
versionMeta =
CodecUtil.checkIndexHeader(
meta,
Lucene99HnswVectorsFormat.META_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
readFields(meta, state.fieldInfos);
} catch (Throwable exception) {
priorE = exception;
} finally {
CodecUtil.checkFooter(meta, priorE);
}
}
return versionMeta;
}
private static IndexInput openDataInput(
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
throws IOException {
String fileName =
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
IndexInput in = state.directory.openInput(fileName, state.context);
boolean success = false;
try {
int versionVectorData =
CodecUtil.checkIndexHeader(
in,
codecName,
Lucene99HnswVectorsFormat.VERSION_START,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (versionMeta != versionVectorData) {
throw new CorruptIndexException(
"Format versions mismatch: meta="
+ versionMeta
+ ", "
+ codecName
+ "="
+ versionVectorData,
in);
}
CodecUtil.retrieveChecksum(in);
success = true;
return in;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(in);
}
}
}
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
FieldInfo info = infos.fieldInfo(fieldNumber);
if (info == null) {
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
}
FieldEntry fieldEntry = readField(meta);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
}
}
private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
int dimension = info.getVectorDimension();
if (dimension != fieldEntry.dimension) {
throw new IllegalStateException(
"Inconsistent vector dimension for field=\""
+ info.name
+ "\"; "
+ dimension
+ " != "
+ fieldEntry.dimension);
}
int byteSize =
switch (info.getVectorEncoding()) {
case BYTE -> Byte.BYTES;
case FLOAT32 -> Float.BYTES;
};
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
if (numBytes != fieldEntry.vectorDataLength) {
throw new IllegalStateException(
"Vector data length "
+ fieldEntry.vectorDataLength
+ " not matching size="
+ fieldEntry.size
+ " * dim="
+ dimension
+ " * byteSize="
+ byteSize
+ " = "
+ numBytes);
}
if (fieldEntry.hasQuantizedVectors()) {
Lucene99ScalarQuantizedVectorsReader.validateFieldEntry(
info, fieldEntry.dimension, fieldEntry.size, fieldEntry.quantizedVectorDataLength);
}
}
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
int similarityFunctionId = input.readInt();
if (similarityFunctionId < 0
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
throw new CorruptIndexException(
"Invalid similarity function id: " + similarityFunctionId, input);
}
return VectorSimilarityFunction.values()[similarityFunctionId];
}
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
int encodingId = input.readInt();
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
}
return VectorEncoding.values()[encodingId];
}
private FieldEntry readField(IndexInput input) throws IOException {
VectorEncoding vectorEncoding = readVectorEncoding(input);
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
return new FieldEntry(input, vectorEncoding, similarityFunction);
}
@Override
public long ramBytesUsed() {
return Lucene99HnswVectorsReader.SHALLOW_SIZE
+ RamUsageEstimator.sizeOfMap(
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
}
@Override
public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorData);
CodecUtil.checksumEntireFile(vectorIndex);
if (quantizedVectorsReader != null) {
quantizedVectorsReader.checkIntegrity();
}
}
@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
}
return OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
}
@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
return OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
}
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
}
if (fieldEntry.hasQuantizedVectors()) {
OffHeapQuantizedByteVectorValues vectorValues =
quantizedVectorsReader.getQuantizedVectorValues(
fieldEntry.quantizedOrdToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.quantizedVectorDataOffset,
fieldEntry.quantizedVectorDataLength);
if (vectorValues == null) {
return;
}
RandomVectorScorer scorer =
new ScalarQuantizedRandomVectorScorer(
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
} else {
OffHeapFloatVectorValues vectorValues =
OffHeapFloatVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}
}
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry.size() == 0
|| knnCollector.k() == 0
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return;
}
OffHeapByteVectorValues vectorValues =
OffHeapByteVectorValues.load(
fieldEntry.ordToDoc,
fieldEntry.vectorEncoding,
fieldEntry.dimension,
fieldEntry.vectorDataOffset,
fieldEntry.vectorDataLength,
vectorData);
RandomVectorScorer scorer =
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
HnswGraphSearcher.search(
scorer,
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
getGraph(fieldEntry),
vectorValues.getAcceptOrds(acceptDocs));
}
@Override
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
return HnswGraph.EMPTY;
}
}
private HnswGraph getGraph(FieldEntry entry) throws IOException {
return new OffHeapHnswGraph(entry, vectorIndex);
}
@Override
public void close() throws IOException {
IOUtils.close(vectorData, vectorIndex, quantizedVectorData);
}
@Override
public OffHeapQuantizedByteVectorValues getQuantizedVectorValues(String field)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null || fieldEntry.hasQuantizedVectors() == false) {
return null;
}
assert quantizedVectorsReader != null && fieldEntry.quantizedOrdToDoc != null;
return quantizedVectorsReader.getQuantizedVectorValues(
fieldEntry.quantizedOrdToDoc,
fieldEntry.dimension,
fieldEntry.size,
fieldEntry.quantizedVectorDataOffset,
fieldEntry.quantizedVectorDataLength);
}
@Override
public ScalarQuantizer getQuantizationState(String fieldName) {
FieldEntry field = fields.get(fieldName);
if (field == null || field.hasQuantizedVectors() == false) {
return null;
}
return field.scalarQuantizer;
}
static class FieldEntry implements Accountable {
private static final long SHALLOW_SIZE =
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
final VectorSimilarityFunction similarityFunction;
final VectorEncoding vectorEncoding;
final long vectorDataOffset;
final long vectorDataLength;
final long vectorIndexOffset;
final long vectorIndexLength;
final int M;
final int numLevels;
final int dimension;
final int size;
final int[][] nodesByLevel;
// for each level the start offsets in vectorIndex file from where to read neighbours
final DirectMonotonicReader.Meta offsetsMeta;
final long offsetsOffset;
final int offsetsBlockShift;
final long offsetsLength;
final OrdToDocDISIReaderConfiguration ordToDoc;
final float configuredQuantile, lowerQuantile, upperQuantile;
final long quantizedVectorDataOffset, quantizedVectorDataLength;
final ScalarQuantizer scalarQuantizer;
final boolean isQuantized;
final OrdToDocDISIReaderConfiguration quantizedOrdToDoc;
FieldEntry(
IndexInput input,
VectorEncoding vectorEncoding,
VectorSimilarityFunction similarityFunction)
throws IOException {
this.similarityFunction = similarityFunction;
this.vectorEncoding = vectorEncoding;
this.isQuantized = input.readByte() == 1;
// Has int8 quantization
if (isQuantized) {
configuredQuantile = Float.intBitsToFloat(input.readInt());
lowerQuantile = Float.intBitsToFloat(input.readInt());
upperQuantile = Float.intBitsToFloat(input.readInt());
quantizedVectorDataOffset = input.readVLong();
quantizedVectorDataLength = input.readVLong();
scalarQuantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, configuredQuantile);
} else {
configuredQuantile = -1;
lowerQuantile = -1;
upperQuantile = -1;
quantizedVectorDataOffset = -1;
quantizedVectorDataLength = -1;
scalarQuantizer = null;
}
vectorDataOffset = input.readVLong();
vectorDataLength = input.readVLong();
vectorIndexOffset = input.readVLong();
vectorIndexLength = input.readVLong();
dimension = input.readVInt();
size = input.readInt();
if (isQuantized) {
quantizedOrdToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
} else {
quantizedOrdToDoc = null;
}
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
// read nodes by level
M = input.readVInt();
numLevels = input.readVInt();
nodesByLevel = new int[numLevels][];
long numberOfOffsets = 0;
for (int level = 0; level < numLevels; level++) {
if (level > 0) {
int numNodesOnLevel = input.readVInt();
numberOfOffsets += numNodesOnLevel;
nodesByLevel[level] = new int[numNodesOnLevel];
nodesByLevel[level][0] = input.readVInt();
for (int i = 1; i < numNodesOnLevel; i++) {
nodesByLevel[level][i] = nodesByLevel[level][i - 1] + input.readVInt();
}
} else {
numberOfOffsets += size;
}
}
if (numberOfOffsets > 0) {
offsetsOffset = input.readLong();
offsetsBlockShift = input.readVInt();
offsetsMeta = DirectMonotonicReader.loadMeta(input, numberOfOffsets, offsetsBlockShift);
offsetsLength = input.readLong();
} else {
offsetsOffset = 0;
offsetsBlockShift = 0;
offsetsMeta = null;
offsetsLength = 0;
}
}
int size() {
return size;
}
boolean hasQuantizedVectors() {
return isQuantized;
}
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE
+ Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum()
+ RamUsageEstimator.sizeOf(ordToDoc)
+ (quantizedOrdToDoc == null ? 0 : RamUsageEstimator.sizeOf(quantizedOrdToDoc))
+ RamUsageEstimator.sizeOf(offsetsMeta);
}
}
/** Read the nearest-neighbors graph from the index input */
private static final class OffHeapHnswGraph extends HnswGraph {
final IndexInput dataIn;
final int[][] nodesByLevel;
final int numLevels;
final int entryNode;
final int size;
int arcCount;
int arcUpTo;
int arc;
private final DirectMonotonicReader graphLevelNodeOffsets;
private final long[] graphLevelNodeIndexOffsets;
// Allocated to be M*2 to track the current neighbors being explored
private final int[] currentNeighborsBuffer;
OffHeapHnswGraph(FieldEntry entry, IndexInput vectorIndex) throws IOException {
this.dataIn =
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
this.nodesByLevel = entry.nodesByLevel;
this.numLevels = entry.numLevels;
this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
this.size = entry.size();
final RandomAccessInput addressesData =
vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
this.graphLevelNodeOffsets =
DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
this.currentNeighborsBuffer = new int[entry.M * 2];
graphLevelNodeIndexOffsets = new long[numLevels];
graphLevelNodeIndexOffsets[0] = 0;
for (int i = 1; i < numLevels; i++) {
// nodesByLevel is `null` for the zeroth level as we know its size
int nodeCount = nodesByLevel[i - 1] == null ? size : nodesByLevel[i - 1].length;
graphLevelNodeIndexOffsets[i] = graphLevelNodeIndexOffsets[i - 1] + nodeCount;
}
}
@Override
public void seek(int level, int targetOrd) throws IOException {
int targetIndex =
level == 0
? targetOrd
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
assert targetIndex >= 0;
// unsafe; no bounds checking
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
arcCount = dataIn.readVInt();
if (arcCount > 0) {
currentNeighborsBuffer[0] = dataIn.readVInt();
for (int i = 1; i < arcCount; i++) {
currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
}
}
arc = -1;
arcUpTo = 0;
}
@Override
public int size() {
return size;
}
@Override
public int nextNeighbor() throws IOException {
if (arcUpTo >= arcCount) {
return NO_MORE_DOCS;
}
arc = currentNeighborsBuffer[arcUpTo];
++arcUpTo;
return arc;
}
@Override
public int numLevels() throws IOException {
return numLevels;
}
@Override
public int entryNode() throws IOException {
return entryNode;
}
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return new ArrayNodesIterator(size());
} else {
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
}
}
}
}
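
As a companion to OffHeapHnswGraph.seek above: a node's on-disk neighbor list is a vint count followed by delta-encoded vint ordinals (see the .vex section of the format javadoc). A standalone sketch of that decode step; the helper name is illustrative:

import java.io.IOException;
import org.apache.lucene.store.DataInput;

final class NeighborListSketch {
  // Decodes one node's neighbor list, mirroring the loop in OffHeapHnswGraph.seek.
  static int[] readNeighbors(DataInput in) throws IOException {
    int count = in.readVInt(); // number of neighbor nodes
    int[] neighbors = new int[count];
    if (count > 0) {
      neighbors[0] = in.readVInt(); // first ordinal stored as-is
      for (int i = 1; i < count; i++) {
        // remaining ordinals are stored as deltas from the previous neighbor
        neighbors[i] = neighbors[i - 1] + in.readVInt();
      }
    }
    return neighbors;
  }
}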

View File

@@ -0,0 +1,985 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.ByteVectorValues;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
import org.apache.lucene.util.hnsw.NeighborArray;
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
import org.apache.lucene.util.packed.DirectMonotonicWriter;
/**
* Writes vector values and knn graphs to index segments.
*
* @lucene.experimental
*/
public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
private final SegmentWriteState segmentWriteState;
private final IndexOutput meta, vectorData, quantizedVectorData, vectorIndex;
private final int M;
private final int beamWidth;
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
private final List<FieldWriter<?>> fields = new ArrayList<>();
private boolean finished;
Lucene99HnswVectorsWriter(
SegmentWriteState state,
int M,
int beamWidth,
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
throws IOException {
this.M = M;
this.beamWidth = beamWidth;
segmentWriteState = state;
String metaFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
String vectorDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION);
String indexDataFileName =
IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
final String quantizedVectorDataFileName =
quantizedVectorsFormat != null
? IndexFileNames.segmentFileName(
state.segmentInfo.name,
state.segmentSuffix,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION)
: null;
boolean success = false;
try {
meta = state.directory.createOutput(metaFileName, state.context);
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
CodecUtil.writeIndexHeader(
meta,
Lucene99HnswVectorsFormat.META_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
vectorData,
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
CodecUtil.writeIndexHeader(
vectorIndex,
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
Lucene99HnswVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
if (quantizedVectorDataFileName != null) {
quantizedVectorData =
state.directory.createOutput(quantizedVectorDataFileName, state.context);
CodecUtil.writeIndexHeader(
quantizedVectorData,
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME,
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
state.segmentInfo.getId(),
state.segmentSuffix);
quantizedVectorsWriter =
new Lucene99ScalarQuantizedVectorsWriter(
quantizedVectorData, quantizedVectorsFormat.quantile);
} else {
quantizedVectorData = null;
quantizedVectorsWriter = null;
}
success = true;
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(this);
}
}
}
@Override
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedVectorFieldWriter =
null;
// Quantization only supports FLOAT32 for now
if (quantizedVectorsWriter != null
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
quantizedVectorFieldWriter =
quantizedVectorsWriter.addField(fieldInfo, segmentWriteState.infoStream);
}
FieldWriter<?> newField =
FieldWriter.create(
fieldInfo, M, beamWidth, segmentWriteState.infoStream, quantizedVectorFieldWriter);
fields.add(newField);
return newField;
}
@Override
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
for (FieldWriter<?> field : fields) {
long[] quantizedVectorOffsetAndLen = null;
if (field.quantizedWriter != null) {
assert quantizedVectorsWriter != null;
quantizedVectorOffsetAndLen =
quantizedVectorsWriter.flush(sortMap, field.quantizedWriter, field.docsWithField);
}
if (sortMap == null) {
writeField(field, maxDoc, quantizedVectorOffsetAndLen);
} else {
writeSortingField(field, maxDoc, sortMap, quantizedVectorOffsetAndLen);
}
}
}
@Override
public void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
if (quantizedVectorsWriter != null) {
quantizedVectorsWriter.finish();
}
if (meta != null) {
// write end of fields marker
meta.writeInt(-1);
CodecUtil.writeFooter(meta);
}
if (vectorData != null) {
CodecUtil.writeFooter(vectorData);
CodecUtil.writeFooter(vectorIndex);
}
}
@Override
public long ramBytesUsed() {
long total = 0;
for (FieldWriter<?> field : fields) {
total += field.ramBytesUsed();
}
return total;
}
private void writeField(FieldWriter<?> fieldData, int maxDoc, long[] quantizedVecOffsetAndLen)
throws IOException {
// write vector values
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectors(fieldData);
case FLOAT32 -> writeFloat32Vectors(fieldData);
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();
int[][] graphLevelNodeOffsets = writeGraph(graph);
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldData.isQuantized(),
fieldData.fieldInfo,
maxDoc,
fieldData.getConfiguredQuantile(),
fieldData.getMinQuantile(),
fieldData.getMaxQuantile(),
quantizedVecOffsetAndLen,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
fieldData.docsWithField,
graph,
graphLevelNodeOffsets);
}
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (Object v : fieldData.vectors) {
buffer.asFloatBuffer().put((float[]) v);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
}
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
for (Object v : fieldData.vectors) {
byte[] vector = (byte[]) v;
vectorData.writeBytes(vector, vector.length);
}
}
private void writeSortingField(
FieldWriter<?> fieldData,
int maxDoc,
Sorter.DocMap sortMap,
long[] quantizedVectorOffsetAndLen)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
oldOrdMap[docIdOffset - 1] = ord;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
// write vector values
long vectorDataOffset =
switch (fieldData.fieldInfo.getVectorEncoding()) {
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
};
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
// write graph
long vectorIndexOffset = vectorIndex.getFilePointer();
OnHeapHnswGraph graph = fieldData.getGraph();
int[][] graphLevelNodeOffsets = graph == null ? new int[0][] : new int[graph.numLevels()][];
HnswGraph mockGraph = reconstructAndWriteGraph(graph, ordMap, oldOrdMap, graphLevelNodeOffsets);
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
fieldData.isQuantized(),
fieldData.fieldInfo,
maxDoc,
fieldData.getConfiguredQuantile(),
fieldData.getMinQuantile(),
fieldData.getMaxQuantile(),
quantizedVectorOffsetAndLen,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
newDocsWithField,
mockGraph,
graphLevelNodeOffsets);
}
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
final ByteBuffer buffer =
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int ordinal : ordMap) {
float[] vector = (float[]) fieldData.vectors.get(ordinal);
buffer.asFloatBuffer().put(vector);
vectorData.writeBytes(buffer.array(), buffer.array().length);
}
return vectorDataOffset;
}
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
for (int ordinal : ordMap) {
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
vectorData.writeBytes(vector, vector.length);
}
return vectorDataOffset;
}
/**
* Reconstructs the graph given the old and new node ids.
*
* <p>Additionally, the graph node connections are written to the vectorIndex.
*
* @param graph The current on heap graph
* @param newToOldMap array indexed by new node id, giving the corresponding old node id
* @param oldToNewMap array indexed by old node id, giving the corresponding new node id
* @param levelNodeOffsets where to place the new offsets for the nodes in the vector index.
* @return The graph
* @throws IOException if writing to vectorIndex fails
*/
private HnswGraph reconstructAndWriteGraph(
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap, int[][] levelNodeOffsets)
throws IOException {
if (graph == null) return null;
List<int[]> nodesByLevel = new ArrayList<>(graph.numLevels());
nodesByLevel.add(null);
int maxOrd = graph.size();
NodesIterator nodesOnLevel0 = graph.getNodesOnLevel(0);
levelNodeOffsets[0] = new int[nodesOnLevel0.size()];
while (nodesOnLevel0.hasNext()) {
int node = nodesOnLevel0.nextInt();
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
for (int level = 1; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
int[] newNodes = new int[nodesOnLevel.size()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
}
Arrays.sort(newNodes);
nodesByLevel.add(newNodes);
levelNodeOffsets[level] = new int[newNodes.length];
int nodeOffsetIndex = 0;
for (int node : newNodes) {
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
long offset = vectorIndex.getFilePointer();
reconstructAndWriteNeighbours(neighbors, oldToNewMap, maxOrd);
levelNodeOffsets[level][nodeOffsetIndex++] =
Math.toIntExact(vectorIndex.getFilePointer() - offset);
}
}
return new HnswGraph() {
@Override
public int nextNeighbor() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public void seek(int level, int target) {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public int size() {
return graph.size();
}
@Override
public int numLevels() {
return graph.numLevels();
}
@Override
public int entryNode() {
throw new UnsupportedOperationException("Not supported on a mock graph");
}
@Override
public NodesIterator getNodesOnLevel(int level) {
if (level == 0) {
return graph.getNodesOnLevel(0);
} else {
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
}
}
};
}
private void reconstructAndWriteNeighbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
throws IOException {
int size = neighbors.size();
vectorIndex.writeVInt(size);
// Destructively modify; it's OK since we discard the array after this
int[] nnodes = neighbors.node();
for (int i = 0; i < size; i++) {
nnodes[i] = oldToNewMap[nnodes[i]];
}
Arrays.sort(nnodes, 0, size);
// Now that we have sorted, do delta encoding to minimize the required bits to store the
// information
for (int i = size - 1; i > 0; --i) {
assert nnodes[i] < maxOrd : "node too large: " + nnodes[i] + ">=" + maxOrd;
nnodes[i] -= nnodes[i - 1];
}
for (int i = 0; i < size; i++) {
vectorIndex.writeVInt(nnodes[i]);
}
}
@Override
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
IndexOutput tempVectorData = null;
IndexInput vectorDataInput = null;
CloseableRandomVectorScorerSupplier scorerSupplier = null;
boolean success = false;
try {
ScalarQuantizer scalarQuantizer = null;
long[] quantizedVectorDataOffsetAndLength = null;
// If we have configured quantization and are FLOAT32
if (quantizedVectorsWriter != null
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
// We need the quantization parameters to write to the meta file
scalarQuantizer = quantizedVectorsWriter.mergeQuantiles(fieldInfo, mergeState);
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
segmentWriteState.infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"Merged quantiles field: "
+ fieldInfo.name
+ " newly merged quantile: "
+ scalarQuantizer);
}
assert scalarQuantizer != null;
quantizedVectorDataOffsetAndLength = new long[2];
quantizedVectorDataOffsetAndLength[0] = quantizedVectorData.alignFilePointer(Float.BYTES);
scorerSupplier =
quantizedVectorsWriter.mergeOneField(
segmentWriteState, fieldInfo, mergeState, scalarQuantizer);
quantizedVectorDataOffsetAndLength[1] =
quantizedVectorData.getFilePointer() - quantizedVectorDataOffsetAndLength[0];
}
final DocsWithFieldSet docsWithField;
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
// If we extract vector storage, this could be cleaner.
// But for now, vector storage & index creation/storage live together.
if (scorerSupplier == null) {
tempVectorData =
segmentWriteState.directory.createTempOutput(
vectorData.getName(), "temp", segmentWriteState.context);
docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
CodecUtil.writeFooter(tempVectorData);
IOUtils.close(tempVectorData);
// copy the temporary file vectors to the actual data file
vectorDataInput =
segmentWriteState.directory.openInput(
tempVectorData.getName(), segmentWriteState.context);
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(vectorDataInput);
final RandomVectorScorerSupplier innerScoreSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(),
docsWithField.cardinality(),
vectorDataInput,
byteSize),
fieldInfo.getVectorSimilarityFunction());
};
final String tempFileName = tempVectorData.getName();
final IndexInput finalVectorDataInput = vectorDataInput;
scorerSupplier =
new CloseableRandomVectorScorerSupplier() {
boolean closed = false;
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return innerScoreSupplier.scorer(ord);
}
@Override
public void close() throws IOException {
if (closed) {
return;
}
closed = true;
IOUtils.close(finalVectorDataInput);
segmentWriteState.directory.deleteFile(tempFileName);
}
};
} else {
// No need to use temporary file as we don't have to re-open for reading
docsWithField =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> writeByteVectorData(
vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
case FLOAT32 -> writeVectorData(
vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
};
}
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
long vectorIndexOffset = vectorIndex.getFilePointer();
// build the graph using the temporary vector data
// we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
// because it doesn't need to know docIds
// TODO: separate random access vector values from DocIdSetIterator?
OnHeapHnswGraph graph = null;
int[][] vectorIndexNodeOffsets = null;
if (docsWithField.cardinality() != 0) {
// build graph
IncrementalHnswGraphMerger merger =
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
merger.addReader(
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
}
DocIdSetIterator mergedVectorIterator = null;
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
case FLOAT32 -> mergedVectorIterator =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
}
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
graph = hnswGraphBuilder.build(docsWithField.cardinality());
vectorIndexNodeOffsets = writeGraph(graph);
}
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
writeMeta(
scalarQuantizer != null,
fieldInfo,
segmentWriteState.segmentInfo.maxDoc(),
scalarQuantizer == null ? null : scalarQuantizer.getConfiguredQuantile(),
scalarQuantizer == null ? null : scalarQuantizer.getLowerQuantile(),
scalarQuantizer == null ? null : scalarQuantizer.getUpperQuantile(),
quantizedVectorDataOffsetAndLength,
vectorDataOffset,
vectorDataLength,
vectorIndexOffset,
vectorIndexLength,
docsWithField,
graph,
vectorIndexNodeOffsets);
success = true;
} finally {
if (success) {
IOUtils.close(scorerSupplier);
} else {
IOUtils.closeWhileHandlingException(scorerSupplier, vectorDataInput, tempVectorData);
if (tempVectorData != null) {
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempVectorData.getName());
}
}
}
}
/**
* @param graph the graph to write in a compressed format
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
* @throws IOException if writing to vectorIndex fails
*/
private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
if (graph == null) return new int[0][0];
// write vectors' neighbours on each level into the vectorIndex file
int countOnLevel0 = graph.size();
int[][] offsets = new int[graph.numLevels()][];
for (int level = 0; level < graph.numLevels(); level++) {
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
offsets[level] = new int[sortedNodes.length];
int nodeOffsetId = 0;
for (int node : sortedNodes) {
NeighborArray neighbors = graph.getNeighbors(level, node);
int size = neighbors.size();
// Write size in VInt as the neighbors list is typically small
long offsetStart = vectorIndex.getFilePointer();
vectorIndex.writeVInt(size);
// Destructively modify; it's OK since we discard the array after this
int[] nnodes = neighbors.node();
Arrays.sort(nnodes, 0, size);
// Now that we have sorted, do delta encoding to minimize the required bits to store the
// information
for (int i = size - 1; i > 0; --i) {
assert nnodes[i] < countOnLevel0 : "node too large: " + nnodes[i] + ">=" + countOnLevel0;
nnodes[i] -= nnodes[i - 1];
}
for (int i = 0; i < size; i++) {
vectorIndex.writeVInt(nnodes[i]);
}
offsets[level][nodeOffsetId++] =
Math.toIntExact(vectorIndex.getFilePointer() - offsetStart);
}
}
return offsets;
}
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
int[] sortedNodes = new int[nodesOnLevel.size()];
for (int n = 0; nodesOnLevel.hasNext(); n++) {
sortedNodes[n] = nodesOnLevel.nextInt();
}
Arrays.sort(sortedNodes);
return sortedNodes;
}
private void writeMeta(
boolean isQuantized,
FieldInfo field,
int maxDoc,
Float configuredQuantizationQuantile,
Float lowerQuantile,
Float upperQuantile,
long[] quantizedVectorDataOffsetAndLen,
long vectorDataOffset,
long vectorDataLength,
long vectorIndexOffset,
long vectorIndexLength,
DocsWithFieldSet docsWithField,
HnswGraph graph,
int[][] graphLevelNodeOffsets)
throws IOException {
meta.writeInt(field.number);
meta.writeInt(field.getVectorEncoding().ordinal());
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
meta.writeByte(isQuantized ? (byte) 1 : (byte) 0);
if (isQuantized) {
assert lowerQuantile != null
&& upperQuantile != null
&& quantizedVectorDataOffsetAndLen != null;
assert quantizedVectorDataOffsetAndLen.length == 2;
meta.writeInt(
Float.floatToIntBits(
configuredQuantizationQuantile != null
? configuredQuantizationQuantile
: calculateDefaultQuantile(field.getVectorDimension())));
meta.writeInt(Float.floatToIntBits(lowerQuantile));
meta.writeInt(Float.floatToIntBits(upperQuantile));
meta.writeVLong(quantizedVectorDataOffsetAndLen[0]);
meta.writeVLong(quantizedVectorDataOffsetAndLen[1]);
} else {
assert configuredQuantizationQuantile == null
&& lowerQuantile == null
&& upperQuantile == null
&& quantizedVectorDataOffsetAndLen == null;
}
meta.writeVLong(vectorDataOffset);
meta.writeVLong(vectorDataLength);
meta.writeVLong(vectorIndexOffset);
meta.writeVLong(vectorIndexLength);
meta.writeVInt(field.getVectorDimension());
// write docIDs
int count = docsWithField.cardinality();
meta.writeInt(count);
if (isQuantized) {
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
}
OrdToDocDISIReaderConfiguration.writeStoredMeta(
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
meta.writeVInt(M);
// write graph nodes on each level
if (graph == null) {
meta.writeVInt(0);
} else {
meta.writeVInt(graph.numLevels());
long valueCount = 0;
for (int level = 0; level < graph.numLevels(); level++) {
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
valueCount += nodesOnLevel.size();
if (level > 0) {
int[] nol = new int[nodesOnLevel.size()];
int numberConsumed = nodesOnLevel.consume(nol);
Arrays.sort(nol);
assert numberConsumed == nodesOnLevel.size();
meta.writeVInt(nol.length); // number of nodes on a level
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {
nol[i] -= nol[i - 1];
}
for (int n : nol) {
assert n >= 0 : "delta encoding for nodes failed; expected nodes to be sorted";
meta.writeVInt(n);
}
} else {
assert nodesOnLevel.size() == count : "Level 0 expects to have all nodes";
}
}
long start = vectorIndex.getFilePointer();
meta.writeLong(start);
meta.writeVInt(DIRECT_MONOTONIC_BLOCK_SHIFT);
final DirectMonotonicWriter memoryOffsetsWriter =
DirectMonotonicWriter.getInstance(
meta, vectorIndex, valueCount, DIRECT_MONOTONIC_BLOCK_SHIFT);
long cumulativeOffsetSum = 0;
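// Each node's start offset is the running sum of the per-node byte lengths
// collected in graphLevelNodeOffsets; the monotonic writer stores them compactly.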
for (int[] levelOffsets : graphLevelNodeOffsets) {
for (int v : levelOffsets) {
memoryOffsetsWriter.add(cumulativeOffsetSum);
cumulativeOffsetSum += v;
}
}
memoryOffsetsWriter.finish();
meta.writeLong(vectorIndex.getFilePointer() - start);
}
}
/**
* Writes the byte vector values to the output and returns a set of documents that contains
* vectors.
*/
private static DocsWithFieldSet writeByteVectorData(
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = byteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = byteVectorValues.nextDoc()) {
// write vector
byte[] binaryValue = byteVectorValues.vectorValue();
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
output.writeBytes(binaryValue, binaryValue.length);
docsWithField.add(docV);
}
return docsWithField;
}
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeVectorData(
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
ByteBuffer buffer =
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
.order(ByteOrder.LITTLE_ENDIAN);
for (int docV = floatVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = floatVectorValues.nextDoc()) {
// write vector
float[] value = floatVectorValues.vectorValue();
buffer.asFloatBuffer().put(value);
output.writeBytes(buffer.array(), buffer.limit());
docsWithField.add(docV);
}
return docsWithField;
}
@Override
public void close() throws IOException {
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
}
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
private final FieldInfo fieldInfo;
private final int dim;
private final DocsWithFieldSet docsWithField;
private final List<T> vectors;
private final HnswGraphBuilder hnswGraphBuilder;
private final Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter
quantizedWriter;
private int lastDocID = -1;
private int node = 0;
static FieldWriter<?> create(
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream,
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter writer)
throws IOException {
int dim = fieldInfo.getVectorDimension();
return switch (fieldInfo.getVectorEncoding()) {
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream, writer) {
@Override
public byte[] copyValue(byte[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream, writer) {
@Override
public float[] copyValue(float[] value) {
return ArrayUtil.copyOfSubArray(value, 0, dim);
}
};
};
}
@SuppressWarnings("unchecked")
FieldWriter(
FieldInfo fieldInfo,
int M,
int beamWidth,
InfoStream infoStream,
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedWriter)
throws IOException {
this.fieldInfo = fieldInfo;
this.dim = fieldInfo.getVectorDimension();
this.docsWithField = new DocsWithFieldSet();
this.quantizedWriter = quantizedWriter;
vectors = new ArrayList<>();
if (quantizedWriter != null
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
throw new IllegalArgumentException(
"Vector encoding ["
+ VectorEncoding.FLOAT32
+ "] required for quantized vectors; provided="
+ fieldInfo.getVectorEncoding());
}
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
RandomVectorScorerSupplier scorerSupplier =
switch (fieldInfo.getVectorEncoding()) {
case BYTE -> RandomVectorScorerSupplier.createBytes(
(RandomAccessVectorValues<byte[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
(RandomAccessVectorValues<float[]>) raVectors,
fieldInfo.getVectorSimilarityFunction());
};
hnswGraphBuilder =
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
hnswGraphBuilder.setInfoStream(infoStream);
}
@Override
public void addValue(int docID, T vectorValue) throws IOException {
if (docID == lastDocID) {
throw new IllegalArgumentException(
"VectorValuesField \""
+ fieldInfo.name
+ "\" appears more than once in this document (only one value is allowed per field)");
}
assert docID > lastDocID;
T copy = copyValue(vectorValue);
if (quantizedWriter != null) {
assert vectorValue instanceof float[];
quantizedWriter.addValue((float[]) copy);
}
docsWithField.add(docID);
vectors.add(copy);
hnswGraphBuilder.addGraphNode(node);
node++;
lastDocID = docID;
}
OnHeapHnswGraph getGraph() {
if (vectors.size() > 0) {
return hnswGraphBuilder.getGraph();
} else {
return null;
}
}
@Override
public long ramBytesUsed() {
if (vectors.size() == 0) return 0;
long quantizationSpace = quantizedWriter != null ? quantizedWriter.ramBytesUsed() : 0L;
return docsWithField.ramBytesUsed()
+ (long) vectors.size()
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
+ (long) vectors.size()
* fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize
+ hnswGraphBuilder.getGraph().ramBytesUsed()
+ quantizationSpace;
}
Float getConfiguredQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getQuantile();
}
Float getMinQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getMinQuantile();
}
Float getMaxQuantile() {
return quantizedWriter == null ? null : quantizedWriter.getMaxQuantile();
}
boolean isQuantized() {
return quantizedWriter != null;
}
}
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
private final List<T> vectors;
private final int dim;
RAVectorValues(List<T> vectors, int dim) {
this.vectors = vectors;
this.dim = dim;
}
@Override
public int size() {
return vectors.size();
}
@Override
public int dimension() {
return dim;
}
@Override
public T vectorValue(int targetOrd) throws IOException {
return vectors.get(targetOrd);
}
@Override
public RandomAccessVectorValues<T> copy() throws IOException {
return this;
}
}
}
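
writeGraph and reconstructAndWriteNeighbours both sort each neighbor list and delta-encode it back-to-front before writing VInts, so small gaps cost few bytes. Below is a standalone sketch of that in-place transform; deltaEncode is a hypothetical helper mirroring the loops above, not part of this change.

import java.util.Arrays;

class NeighborDeltaEncodeSketch {
  // Sort, then replace each ordinal (except the first) with its gap to the
  // previous one; iterating back-to-front keeps each predecessor absolute
  // until it has been consumed.
  static void deltaEncode(int[] nnodes, int size) {
    Arrays.sort(nnodes, 0, size);
    for (int i = size - 1; i > 0; --i) {
      nnodes[i] -= nnodes[i - 1];
    }
  }

  public static void main(String[] args) {
    int[] nnodes = {10, 3, 5};
    deltaEncode(nnodes, nnodes.length);
    System.out.println(Arrays.toString(nnodes)); // prints [3, 2, 5]
  }
}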

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
/**
* Format supporting vector quantization, storage, and retrieval
*
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsFormat {
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
static final int VERSION_START = 0;
static final int VERSION_CURRENT = VERSION_START;
static final String QUANTIZED_VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsData";
static final String QUANTIZED_VECTOR_DATA_EXTENSION = "veq";
/** The minimum quantile */
private static final float MINIMUM_QUANTILE = 0.9f;
/** The maximum quantile */
private static final float MAXIMUM_QUANTILE = 1f;
/**
* Controls the quantile used to scalar quantize the vectors. When unset (`null`), the default
* quantile is calculated as `1 - 1/(vector_dimensions + 1)`.
*/
final Float quantile;
/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
this(null);
}
/**
* Constructs a format using the given graph construction parameters.
*
* @param quantile the quantile for scalar quantizing the vectors, when `null` it is calculated
* based on the vector field dimensions.
*/
public Lucene99ScalarQuantizedVectorsFormat(Float quantile) {
if (quantile != null && (quantile < MINIMUM_QUANTILE || quantile > MAXIMUM_QUANTILE)) {
throw new IllegalArgumentException(
"quantile must be between "
+ MINIMUM_QUANTILE
+ " and "
+ MAXIMUM_QUANTILE
+ "; quantile="
+ quantile);
}
this.quantile = quantile;
}
static float calculateDefaultQuantile(int vectorDimension) {
return Math.max(MINIMUM_QUANTILE, 1f - (1f / (vectorDimension + 1)));
}
@Override
public String toString() {
return NAME + "(name=" + NAME + ", quantile=" + quantile + ")";
}
}
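
The quantile is the only knob this format exposes: explicit values must fall within [0.9, 1.0], and a null quantile defers to a per-field default derived from the vector dimension. A short usage sketch under those rules follows; the printed values are illustrative only.

import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;

class QuantileConfigSketch {
  public static void main(String[] args) {
    // Explicit quantile; anything outside [0.9, 1.0] throws IllegalArgumentException.
    Lucene99ScalarQuantizedVectorsFormat explicit =
        new Lucene99ScalarQuantizedVectorsFormat(0.95f);
    // Null quantile: each field falls back to calculateDefaultQuantile(dim),
    // e.g. for 768 dimensions, max(0.9, 1 - 1/(768 + 1)) ≈ 0.9987.
    Lucene99ScalarQuantizedVectorsFormat derived = new Lucene99ScalarQuantizedVectorsFormat();
    System.out.println(explicit + " / " + derived);
  }
}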

View File

@ -0,0 +1,88 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.store.IndexInput;
/**
* Reads Scalar Quantized vectors from the index segments along with index data structures.
*
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsReader {
private final IndexInput quantizedVectorData;
Lucene99ScalarQuantizedVectorsReader(IndexInput quantizedVectorData) {
this.quantizedVectorData = quantizedVectorData;
}
static void validateFieldEntry(
FieldInfo info, int fieldDimension, int size, long quantizedVectorDataLength) {
int dimension = info.getVectorDimension();
if (dimension != fieldDimension) {
throw new IllegalStateException(
"Inconsistent vector dimension for field=\""
+ info.name
+ "\"; "
+ dimension
+ " != "
+ fieldDimension);
}
// Each vector stores its int8-quantized values plus a 4-byte float offset correction.
long quantizedVectorBytes = dimension + Float.BYTES;
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, size);
if (numQuantizedVectorBytes != quantizedVectorDataLength) {
throw new IllegalStateException(
"Quantized vector data length "
+ quantizedVectorDataLength
+ " not matching size="
+ size
+ " * (dim="
+ dimension
+ " + 4)"
+ " = "
+ numQuantizedVectorBytes);
}
}
public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(quantizedVectorData);
}
OffHeapQuantizedByteVectorValues getQuantizedVectorValues(
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
long quantizedVectorDataOffset,
long quantizedVectorDataLength)
throws IOException {
return OffHeapQuantizedByteVectorValues.load(
configuration,
dimension,
size,
quantizedVectorDataOffset,
quantizedVectorDataLength,
quantizedVectorData);
}
}
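
validateFieldEntry above expects the quantized data length to equal size * (dimension + 4): each vector is stored as one int8 byte per dimension followed by a 4-byte float score-correction constant. Below is a sketch of the same arithmetic; expectedQuantizedDataLength is a hypothetical helper, not part of this change.

class QuantizedLayoutSketch {
  // One byte per dimension plus Float.BYTES for the per-vector offset correction.
  static long expectedQuantizedDataLength(int dimension, int numVectors) {
    return Math.multiplyExact((long) dimension + Float.BYTES, (long) numVectors);
  }

  public static void main(String[] args) {
    // 768 dims across 1,000,000 vectors -> 772,000,000 bytes (~736 MiB)
    System.out.println(expectedQuantizedDataLength(768, 1_000_000));
  }
}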

View File

@ -0,0 +1,824 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.index.DocIDMerger;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.MergeState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.index.Sorter;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.store.IndexOutput;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/**
* Writes quantized vector values and metadata to index segments.
*
* @lucene.experimental
*/
public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
private static final long BASE_RAM_BYTES_USED =
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
// Used for determining when merged quantiles shifted too far from individual segment quantiles.
// When merging quantiles from various segments, we need to ensure that the new quantiles
// are not exceptionally different from an individual segment's quantiles.
// This would imply that the quantization buckets would shift too much
// for floating point values and justify recalculating the quantiles. This helps preserve
// accuracy of the calculated quantiles, even in adversarial cases such as vector clustering.
// This number was determined via empirical testing
private static final float QUANTILE_RECOMPUTE_LIMIT = 32;
// Used for determining if a new quantization state requires a re-quantization
// for a given segment.
// This ensures that in expectation 4/5 of the vector values would be unchanged by requantization.
// Furthermore, only those values where the value is within 1/5 of the centre of a quantization
// bin will be changed. In these cases the error introduced by snapping one way or another
// is small compared to the error introduced by quantization in the first place. Furthermore,
// empirical testing showed that the relative error by not requantizing is small (compared to
// the quantization error) and the condition is sensitive enough to detect all adversarial cases,
// such as merging clustered data.
private static final float REQUANTIZATION_LIMIT = 0.2f;
private final IndexOutput quantizedVectorData;
private final Float quantile;
private boolean finished;
Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) {
this.quantile = quantile;
this.quantizedVectorData = quantizedVectorData;
}
QuantizationFieldVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) {
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"Only float32 vector fields are supported for quantization");
}
float quantile =
this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
: this.quantile;
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"quantizing field="
+ fieldInfo.name
+ " dimension="
+ fieldInfo.getVectorDimension()
+ " quantile="
+ quantile);
}
return QuantizationFieldVectorWriter.create(fieldInfo, quantile, infoStream);
}
long[] flush(
Sorter.DocMap sortMap, QuantizationFieldVectorWriter field, DocsWithFieldSet docsWithField)
throws IOException {
field.finish();
return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField);
}
void finish() throws IOException {
if (finished) {
throw new IllegalStateException("already finished");
}
finished = true;
if (quantizedVectorData != null) {
CodecUtil.writeFooter(quantizedVectorData);
}
}
private long[] writeField(QuantizationFieldVectorWriter fieldData) throws IOException {
long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeQuantizedVectors(fieldData);
long quantizedVectorDataLength =
quantizedVectorData.getFilePointer() - quantizedVectorDataOffset;
return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength};
}
private void writeQuantizedVectors(QuantizationFieldVectorWriter fieldData) throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (float[] v : fieldData.floatVectors) {
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
quantizedVectorData.writeBytes(vector, vector.length);
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
}
}
private long[] writeSortingField(
QuantizationFieldVectorWriter fieldData,
Sorter.DocMap sortMap,
DocsWithFieldSet docsWithField)
throws IOException {
final int[] docIdOffsets = new int[sortMap.size()];
int offset = 1; // 0 means no vector for this (field, document)
DocIdSetIterator iterator = docsWithField.iterator();
for (int docID = iterator.nextDoc();
docID != DocIdSetIterator.NO_MORE_DOCS;
docID = iterator.nextDoc()) {
int newDocID = sortMap.oldToNew(docID);
docIdOffsets[newDocID] = offset++;
}
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
final int[] ordMap = new int[offset - 1]; // new ord to old ord
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
int ord = 0;
int doc = 0;
for (int docIdOffset : docIdOffsets) {
if (docIdOffset != 0) {
ordMap[ord] = docIdOffset - 1;
oldOrdMap[docIdOffset - 1] = ord;
newDocsWithField.add(doc);
ord++;
}
doc++;
}
// write vector values
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
writeSortedQuantizedVectors(fieldData, ordMap);
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
return new long[] {vectorDataOffset, quantizedVectorLength};
}
void writeSortedQuantizedVectors(QuantizationFieldVectorWriter fieldData, int[] ordMap)
throws IOException {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal);
float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
quantizedVectorData.writeBytes(vector, vector.length);
offsetBuffer.putFloat(offsetCorrection);
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
offsetBuffer.rewind();
}
}
ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
float quantile =
this.quantile == null
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
: this.quantile;
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
}
ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField(
SegmentWriteState segmentWriteState,
FieldInfo fieldInfo,
MergeState mergeState,
ScalarQuantizer mergedQuantizationState)
throws IOException {
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
IndexOutput tempQuantizedVectorData =
segmentWriteState.directory.createTempOutput(
quantizedVectorData.getName(), "temp", segmentWriteState.context);
IndexInput quantizationDataInput = null;
boolean success = false;
try {
MergedQuantizedVectorValues byteVectorValues =
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
fieldInfo, mergeState, mergedQuantizationState);
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
CodecUtil.writeFooter(tempQuantizedVectorData);
IOUtils.close(tempQuantizedVectorData);
quantizationDataInput =
segmentWriteState.directory.openInput(
tempQuantizedVectorData.getName(), segmentWriteState.context);
quantizedVectorData.copyBytes(
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
CodecUtil.retrieveChecksum(quantizationDataInput);
success = true;
final IndexInput finalQuantizationDataInput = quantizationDataInput;
return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
() -> {
IOUtils.close(finalQuantizationDataInput);
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
},
new ScalarQuantizedRandomVectorScorerSupplier(
fieldInfo.getVectorSimilarityFunction(),
mergedQuantizationState,
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
fieldInfo.getVectorDimension(), byteVectorValues.size(), quantizationDataInput)));
} finally {
if (success == false) {
IOUtils.closeWhileHandlingException(quantizationDataInput);
IOUtils.deleteFilesIgnoringExceptions(
segmentWriteState.directory, tempQuantizedVectorData.getName());
}
}
}
static ScalarQuantizer mergeQuantiles(
List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, float quantile) {
assert quantizationStates.size() == segmentSizes.size();
if (quantizationStates.isEmpty()) {
return null;
}
float lowerQuantile = 0f;
float upperQuantile = 0f;
int totalCount = 0;
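// Weighted average: each segment contributes its quantiles in proportion to its
// vector count, i.e. merged = sum(q_i * n_i) / sum(n_i).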
for (int i = 0; i < quantizationStates.size(); i++) {
if (quantizationStates.get(i) == null) {
return null;
}
lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i);
upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i);
totalCount += segmentSizes.get(i);
}
lowerQuantile /= totalCount;
upperQuantile /= totalCount;
return new ScalarQuantizer(lowerQuantile, upperQuantile, quantile);
}
/**
* Returns true if the quantiles of the merged state are too far from the quantiles of the
* individual states.
*
* @param mergedQuantizationState The merged quantization state
* @param quantizationStates The quantization states of the individual segments
* @return true if the quantiles should be recomputed
*/
static boolean shouldRecomputeQuantiles(
ScalarQuantizer mergedQuantizationState, List<ScalarQuantizer> quantizationStates) {
// Calculate the limit for the quantiles to be considered too far apart.
// We utilize the upper & lower quantiles here to determine whether a segment's
// quantiles and the merged quantiles would drastically change the quantization
// buckets for floats. This is a fairly conservative check.
float limit =
(mergedQuantizationState.getUpperQuantile() - mergedQuantizationState.getLowerQuantile())
/ QUANTILE_RECOMPUTE_LIMIT;
for (ScalarQuantizer quantizationState : quantizationStates) {
if (Math.abs(
quantizationState.getUpperQuantile() - mergedQuantizationState.getUpperQuantile())
> limit) {
return true;
}
if (Math.abs(
quantizationState.getLowerQuantile() - mergedQuantizationState.getLowerQuantile())
> limit) {
return true;
}
}
return false;
}
private static QuantizedVectorsReader getQuantizedKnnVectorsReader(
KnnVectorsReader vectorsReader, String fieldName) {
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
vectorsReader = candidateReader.getFieldReader(fieldName);
}
if (vectorsReader instanceof QuantizedVectorsReader reader) {
return reader;
}
return null;
}
private static ScalarQuantizer getQuantizedState(
KnnVectorsReader vectorsReader, String fieldName) {
QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(vectorsReader, fieldName);
if (reader != null) {
return reader.getQuantizationState(fieldName);
}
return null;
}
static ScalarQuantizer mergeAndRecalculateQuantiles(
MergeState mergeState, FieldInfo fieldInfo, float quantile) throws IOException {
List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
for (int i = 0; i < mergeState.liveDocs.length; i++) {
FloatVectorValues fvv;
if (mergeState.knnVectorsReaders[i] != null
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
&& fvv.size() > 0) {
ScalarQuantizer quantizationState =
getQuantizedState(mergeState.knnVectorsReaders[i], fieldInfo.name);
// If we have quantization state, we can utilize that to make merging cheaper
quantizationStates.add(quantizationState);
segmentSizes.add(fvv.size());
}
}
ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, quantile);
// Segments not providing quantization state indicate that their quantiles were
// never calculated. To be safe, we should always recalculate from a sample set
// over all the float vectors in the merged segment view.
if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
FloatVectorValues vectorValues =
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, quantile);
}
return mergedQuantiles;
}
/**
* Returns true if the quantiles of the new quantization state are too far from the quantiles of
* the existing quantization state. This would imply that floating point values would shift
* slightly between quantization buckets, justifying requantization.
*
* @param existingQuantiles The existing quantiles for a segment
* @param newQuantiles The new quantiles for a segment, could be merged, or fully re-calculated
* @return true if the floating point values should be requantized
*/
static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) {
float tol =
REQUANTIZATION_LIMIT
* (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile())
/ 128f;
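// The tolerance is REQUANTIZATION_LIMIT times the width of one quantization
// bucket (the quantile range split into 128 steps); smaller shifts leave most
// quantized values unchanged.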
if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) {
return true;
}
return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol;
}
/**
* Writes the vector values to the output and returns a set of documents that contains vectors.
*/
private static DocsWithFieldSet writeQuantizedVectorData(
IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
for (int docV = quantizedByteVectorValues.nextDoc();
docV != NO_MORE_DOCS;
docV = quantizedByteVectorValues.nextDoc()) {
// write vector
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
assert binaryValue.length == quantizedByteVectorValues.dimension()
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
output.writeBytes(binaryValue, binaryValue.length);
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
docsWithField.add(docV);
}
return docsWithField;
}
@Override
public long ramBytesUsed() {
return BASE_RAM_BYTES_USED;
}
static class QuantizationFieldVectorWriter implements Accountable {
private static final long SHALLOW_SIZE =
shallowSizeOfInstance(QuantizationFieldVectorWriter.class);
private final int dim;
private final List<float[]> floatVectors;
private final boolean normalize;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final float quantile;
private final InfoStream infoStream;
private float minQuantile = Float.POSITIVE_INFINITY;
private float maxQuantile = Float.NEGATIVE_INFINITY;
private boolean finished;
static QuantizationFieldVectorWriter create(
FieldInfo fieldInfo, float quantile, InfoStream infoStream) {
return new QuantizationFieldVectorWriter(
fieldInfo.getVectorDimension(),
quantile,
fieldInfo.getVectorSimilarityFunction(),
infoStream);
}
QuantizationFieldVectorWriter(
int dim,
float quantile,
VectorSimilarityFunction vectorSimilarityFunction,
InfoStream infoStream) {
this.dim = dim;
this.quantile = quantile;
this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE;
this.vectorSimilarityFunction = vectorSimilarityFunction;
this.floatVectors = new ArrayList<>();
this.infoStream = infoStream;
}
void finish() throws IOException {
if (finished) {
return;
}
if (floatVectors.size() == 0) {
finished = true;
return;
}
ScalarQuantizer quantizer =
ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile);
minQuantile = quantizer.getLowerQuantile();
maxQuantile = quantizer.getUpperQuantile();
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
infoStream.message(
QUANTIZED_VECTOR_COMPONENT,
"quantized field="
+ " dimension="
+ dim
+ " quantile="
+ quantile
+ " minQuantile="
+ minQuantile
+ " maxQuantile="
+ maxQuantile);
}
finished = true;
}
public void addValue(float[] vectorValue) throws IOException {
floatVectors.add(vectorValue);
}
float getMinQuantile() {
assert finished;
return minQuantile;
}
float getMaxQuantile() {
assert finished;
return maxQuantile;
}
float getQuantile() {
return quantile;
}
ScalarQuantizer createQuantizer() {
assert finished;
return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
}
@Override
public long ramBytesUsed() {
if (floatVectors.size() == 0) return SHALLOW_SIZE;
return SHALLOW_SIZE + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
}
}
static class FloatVectorWrapper extends FloatVectorValues {
private final List<float[]> vectorList;
private final float[] copy;
private final boolean normalize;
protected int curDoc = -1;
FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
this.vectorList = vectorList;
this.copy = new float[vectorList.get(0).length];
this.normalize = normalize;
}
@Override
public int dimension() {
return vectorList.get(0).length;
}
@Override
public int size() {
return vectorList.size();
}
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= vectorList.size()) {
throw new IOException("Current doc not set or too many iterations");
}
if (normalize) {
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
return copy;
}
return vectorList.get(curDoc);
}
@Override
public int docID() {
if (curDoc >= vectorList.size()) {
return NO_MORE_DOCS;
}
return curDoc;
}
@Override
public int nextDoc() throws IOException {
curDoc++;
return docID();
}
@Override
public int advance(int target) throws IOException {
curDoc = target;
return docID();
}
}
private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
private final QuantizedByteVectorValues values;
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
super(docMap);
this.values = values;
assert values.docID() == -1;
}
@Override
public int nextDoc() throws IOException {
return values.nextDoc();
}
}
/** Returns a merged view over all the segment's {@link QuantizedByteVectorValues}. */
static class MergedQuantizedVectorValues extends QuantizedByteVectorValues {
public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues(
FieldInfo fieldInfo, MergeState mergeState, ScalarQuantizer scalarQuantizer)
throws IOException {
assert fieldInfo != null && fieldInfo.hasVectorValues();
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
if (mergeState.knnVectorsReaders[i] != null
&& mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) {
QuantizedVectorsReader reader =
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
assert scalarQuantizer != null;
final QuantizedByteVectorValueSub sub;
// Either our quantization parameters are way different from the merged ones,
// or we have never been quantized.
if (reader == null
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
sub =
new QuantizedByteVectorValueSub(
mergeState.docMaps[i],
new QuantizedFloatVectorValues(
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name),
fieldInfo.getVectorSimilarityFunction(),
scalarQuantizer));
} else {
sub =
new QuantizedByteVectorValueSub(
mergeState.docMaps[i],
new OffsetCorrectedQuantizedByteVectorValues(
reader.getQuantizedVectorValues(fieldInfo.name),
fieldInfo.getVectorSimilarityFunction(),
scalarQuantizer,
reader.getQuantizationState(fieldInfo.name)));
}
subs.add(sub);
}
}
return new MergedQuantizedVectorValues(subs, mergeState);
}
private final List<QuantizedByteVectorValueSub> subs;
private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
private final int size;
private int docId;
private QuantizedByteVectorValueSub current;
private MergedQuantizedVectorValues(
List<QuantizedByteVectorValueSub> subs, MergeState mergeState) throws IOException {
this.subs = subs;
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
int totalSize = 0;
for (QuantizedByteVectorValueSub sub : subs) {
totalSize += sub.values.size();
}
size = totalSize;
docId = -1;
}
@Override
public byte[] vectorValue() throws IOException {
return current.values.vectorValue();
}
@Override
public int docID() {
return docId;
}
@Override
public int nextDoc() throws IOException {
current = docIdMerger.next();
if (current == null) {
docId = NO_MORE_DOCS;
} else {
docId = current.mappedDocID;
}
return docId;
}
@Override
public int advance(int target) {
throw new UnsupportedOperationException();
}
@Override
public int size() {
return size;
}
@Override
public int dimension() {
return subs.get(0).values.dimension();
}
@Override
float getScoreCorrectionConstant() throws IOException {
return current.values.getScoreCorrectionConstant();
}
}
private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
private float offsetValue = 0f;
private final VectorSimilarityFunction vectorSimilarityFunction;
public QuantizedFloatVectorValues(
FloatVectorValues values,
VectorSimilarityFunction vectorSimilarityFunction,
ScalarQuantizer quantizer) {
this.values = values;
this.quantizer = quantizer;
this.quantizedVector = new byte[values.dimension()];
this.vectorSimilarityFunction = vectorSimilarityFunction;
}
@Override
float getScoreCorrectionConstant() {
return offsetValue;
}
@Override
public int dimension() {
return values.dimension();
}
@Override
public int size() {
return values.size();
}
@Override
public byte[] vectorValue() throws IOException {
return quantizedVector;
}
@Override
public int docID() {
return values.docID();
}
@Override
public int nextDoc() throws IOException {
int doc = values.nextDoc();
if (doc != NO_MORE_DOCS) {
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
return doc;
}
@Override
public int advance(int target) throws IOException {
int doc = values.advance(target);
if (doc != NO_MORE_DOCS) {
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
return doc;
}
}
static final class ScalarQuantizedCloseableRandomVectorScorerSupplier
implements CloseableRandomVectorScorerSupplier {
private final ScalarQuantizedRandomVectorScorerSupplier supplier;
private final Closeable onClose;
ScalarQuantizedCloseableRandomVectorScorerSupplier(
Closeable onClose, ScalarQuantizedRandomVectorScorerSupplier supplier) {
this.onClose = onClose;
this.supplier = supplier;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
return supplier.scorer(ord);
}
@Override
public void close() throws IOException {
onClose.close();
}
}
private static final class OffsetCorrectedQuantizedByteVectorValues
extends QuantizedByteVectorValues {
private final QuantizedByteVectorValues in;
private final VectorSimilarityFunction vectorSimilarityFunction;
private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;
private OffsetCorrectedQuantizedByteVectorValues(
QuantizedByteVectorValues in,
VectorSimilarityFunction vectorSimilarityFunction,
ScalarQuantizer scalarQuantizer,
ScalarQuantizer oldScalarQuantizer) {
this.in = in;
this.vectorSimilarityFunction = vectorSimilarityFunction;
this.scalarQuantizer = scalarQuantizer;
this.oldScalarQuantizer = oldScalarQuantizer;
}
@Override
float getScoreCorrectionConstant() throws IOException {
return scalarQuantizer.recalculateCorrectiveOffset(
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
}
@Override
public int dimension() {
return in.dimension();
}
@Override
public int size() {
return in.size();
}
@Override
public byte[] vectorValue() throws IOException {
return in.vectorValue();
}
@Override
public int docID() {
return in.docID();
}
@Override
public int nextDoc() throws IOException {
return in.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return in.advance(target);
}
}
}
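The writer machinery above buffers raw float vectors via addValue, derives [minQuantile, maxQuantile] once the field is finished, and only then produces byte vectors. Below is a minimal, self-contained sketch of that collect-then-quantize flow in plain Java rather than the Lucene types, with the quantile fixed at 1.0 (plain min/max) for brevity:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Illustrative only: mirrors the collect-then-quantize flow of the writer above,
// using plain arrays instead of Lucene's FloatVectorValues/ScalarQuantizer.
class FlushSketch {
  public static void main(String[] args) {
    List<float[]> collected = new ArrayList<>();
    collected.add(new float[] {0.1f, 0.9f});
    collected.add(new float[] {0.4f, 0.6f});
    // 1. Scan all collected values to find the quantile range
    //    (quantile = 1.0 here, i.e. the plain min and max).
    float min = Float.POSITIVE_INFINITY, max = Float.NEGATIVE_INFINITY;
    for (float[] v : collected) {
      for (float x : v) {
        min = Math.min(min, x);
        max = Math.max(max, x);
      }
    }
    // 2. Quantize each buffered vector into [0, 127] using the derived range.
    float scale = 127f / (max - min);
    for (float[] v : collected) {
      byte[] q = new byte[v.length];
      for (int i = 0; i < v.length; i++) {
        q[i] = (byte) Math.round((Math.max(min, Math.min(max, v[i])) - min) * scale);
      }
      System.out.println(Arrays.toString(q));
    }
  }
}

Deferring quantization until all vectors are buffered is what lets the quantile estimate reflect the whole flushed segment rather than any single vector.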

View File

@ -0,0 +1,275 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.lucene.codecs.lucene90.IndexedDISI;
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
import org.apache.lucene.store.IndexInput;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.packed.DirectMonotonicReader;
/**
* Read the quantized vector values and their score correction values from the index input. This
* supports both iterated and random access.
*/
abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
implements RandomAccessQuantizedByteVectorValues {
protected final int dimension;
protected final int size;
protected final IndexInput slice;
protected final byte[] binaryValue;
protected final ByteBuffer byteBuffer;
protected final int byteSize;
protected int lastOrd = -1;
protected final float[] scoreCorrectionConstant = new float[1];
OffHeapQuantizedByteVectorValues(int dimension, int size, IndexInput slice) {
this.dimension = dimension;
this.size = size;
this.slice = slice;
this.byteSize = dimension + Float.BYTES;
byteBuffer = ByteBuffer.allocate(dimension);
binaryValue = byteBuffer.array();
}
@Override
public int dimension() {
return dimension;
}
@Override
public int size() {
return size;
}
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
if (lastOrd == targetOrd) {
return binaryValue;
}
slice.seek((long) targetOrd * byteSize);
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension);
slice.readFloats(scoreCorrectionConstant, 0, 1);
lastOrd = targetOrd;
return binaryValue;
}
@Override
public float getScoreCorrectionConstant() {
return scoreCorrectionConstant[0];
}
static OffHeapQuantizedByteVectorValues load(
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
long quantizedVectorDataOffset,
long quantizedVectorDataLength,
IndexInput vectorData)
throws IOException {
if (configuration.isEmpty()) {
return new EmptyOffHeapVectorValues(dimension);
}
IndexInput bytesSlice =
vectorData.slice(
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
if (configuration.isDense()) {
return new DenseOffHeapVectorValues(dimension, size, bytesSlice);
} else {
return new SparseOffHeapVectorValues(configuration, dimension, size, vectorData, bytesSlice);
}
}
abstract Bits getAcceptOrds(Bits acceptDocs);
static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
private int doc = -1;
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
super(dimension, size, slice);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(doc);
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
if (target >= size) {
return doc = NO_MORE_DOCS;
}
return doc = target;
}
@Override
public DenseOffHeapVectorValues copy() throws IOException {
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
return acceptDocs;
}
}
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
private final DirectMonotonicReader ordToDoc;
private final IndexedDISI disi;
// dataIn is retained so #copy() can create fresh IndexedDISI and ordToDoc readers
private final IndexInput dataIn;
private final OrdToDocDISIReaderConfiguration configuration;
public SparseOffHeapVectorValues(
OrdToDocDISIReaderConfiguration configuration,
int dimension,
int size,
IndexInput dataIn,
IndexInput slice)
throws IOException {
super(dimension, size, slice);
this.configuration = configuration;
this.dataIn = dataIn;
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
this.disi = configuration.getIndexedDISI(dataIn);
}
@Override
public byte[] vectorValue() throws IOException {
return vectorValue(disi.index());
}
@Override
public int docID() {
return disi.docID();
}
@Override
public int nextDoc() throws IOException {
return disi.nextDoc();
}
@Override
public int advance(int target) throws IOException {
assert docID() < target;
return disi.advance(target);
}
@Override
public SparseOffHeapVectorValues copy() throws IOException {
return new SparseOffHeapVectorValues(configuration, dimension, size, dataIn, slice.clone());
}
@Override
public int ordToDoc(int ord) {
return (int) ordToDoc.get(ord);
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
if (acceptDocs == null) {
return null;
}
return new Bits() {
@Override
public boolean get(int index) {
return acceptDocs.get(ordToDoc(index));
}
@Override
public int length() {
return size;
}
};
}
}
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
public EmptyOffHeapVectorValues(int dimension) {
super(dimension, 0, null);
}
private int doc = -1;
@Override
public int dimension() {
return super.dimension();
}
@Override
public int size() {
return 0;
}
@Override
public byte[] vectorValue() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return doc;
}
@Override
public int nextDoc() throws IOException {
return advance(doc + 1);
}
@Override
public int advance(int target) throws IOException {
return doc = NO_MORE_DOCS;
}
@Override
public EmptyOffHeapVectorValues copy() throws IOException {
throw new UnsupportedOperationException();
}
@Override
public byte[] vectorValue(int targetOrd) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int ordToDoc(int ord) {
throw new UnsupportedOperationException();
}
@Override
Bits getAcceptOrds(Bits acceptDocs) {
return null;
}
}
}
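The reader above assumes a flat record layout: each ordinal occupies byteSize = dimension + Float.BYTES bytes, i.e. the quantized vector bytes immediately followed by a single float score correction, which is why vectorValue(int) seeks to (long) targetOrd * byteSize. A standalone sketch of that addressing with plain java.nio (the little-endian byte order here is an assumption for illustration):

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

// Illustrative only: decodes record `ord` from a flat [bytes..., float] layout
// like the one OffHeapQuantizedByteVectorValues seeks through.
class RecordLayoutSketch {
  public static void main(String[] args) {
    int dimension = 4;
    int byteSize = dimension + Float.BYTES;
    ByteBuffer data = ByteBuffer.allocate(2 * byteSize).order(ByteOrder.LITTLE_ENDIAN);
    data.put(new byte[] {1, 2, 3, 4}).putFloat(0.5f);  // ord 0
    data.put(new byte[] {5, 6, 7, 8}).putFloat(0.25f); // ord 1
    int ord = 1;
    data.position(ord * byteSize); // equivalent of slice.seek((long) targetOrd * byteSize)
    byte[] vector = new byte[dimension];
    data.get(vector);                // dimension quantized bytes
    float correction = data.getFloat(); // then the score correction constant
    System.out.println(vector[0] + " .. correction=" + correction);
  }
}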

View File

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.index.ByteVectorValues;
/**
* A version of {@link ByteVectorValues} that additionally exposes the score correction offset
* needed when scoring scalar-quantized vectors.
*/
abstract class QuantizedByteVectorValues extends ByteVectorValues {
abstract float getScoreCorrectionConstant() throws IOException;
}

View File

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.Closeable;
import java.io.IOException;
import org.apache.lucene.util.Accountable;
import org.apache.lucene.util.ScalarQuantizer;
/** Quantized vector reader */
interface QuantizedVectorsReader extends Closeable, Accountable {
QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException;
ScalarQuantizer getQuantizationState(String fieldName);
}

View File

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
/**
* Random access values for <code>byte[]</code> vectors, which also expose the score correction
* constant for the vector most recently read into the buffer.
*/
interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues<byte[]> {
float getScoreCorrectionConstant();
@Override
RandomAccessQuantizedByteVectorValues copy() throws IOException;
}

View File

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
/** Quantized vector scorer */
final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
private static float quantizeQuery(
float[] query,
byte[] quantizedQuery,
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer) {
float[] processedQuery =
switch (similarityFunction) {
case EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> query;
case COSINE -> {
float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length);
VectorUtil.l2normalize(queryCopy);
yield queryCopy;
}
};
return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction);
}
private final byte[] quantizedQuery;
private final float queryOffset;
private final RandomAccessQuantizedByteVectorValues values;
private final ScalarQuantizedVectorSimilarity similarity;
ScalarQuantizedRandomVectorScorer(
ScalarQuantizedVectorSimilarity similarityFunction,
RandomAccessQuantizedByteVectorValues values,
byte[] query,
float queryOffset) {
this.quantizedQuery = query;
this.queryOffset = queryOffset;
this.similarity = similarityFunction;
this.values = values;
}
ScalarQuantizedRandomVectorScorer(
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values,
float[] query) {
byte[] quantizedQuery = new byte[query.length];
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
this.quantizedQuery = quantizedQuery;
this.queryOffset = correction;
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier());
this.values = values;
}
@Override
public float score(int node) throws IOException {
byte[] storedVectorValue = values.vectorValue(node);
float storedVectorCorrection = values.getScoreCorrectionConstant();
return similarity.score(
quantizedQuery, this.queryOffset, storedVectorValue, storedVectorCorrection);
}
}
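In quantizeQuery above, only COSINE needs preprocessing: the query is copied and the copy is l2-normalized so the caller's array is never mutated, and the normalized copy is what gets quantized. A minimal sketch of that normalization step in plain Java:

// Illustrative only: the COSINE branch of quantizeQuery above copies the query
// and l2-normalizes the copy, leaving the original array untouched.
class QueryPrepSketch {
  public static void main(String[] args) {
    float[] query = {3f, 4f};
    float[] copy = java.util.Arrays.copyOf(query, query.length);
    float norm = 0f;
    for (float v : copy) norm += v * v;
    norm = (float) Math.sqrt(norm);
    for (int i = 0; i < copy.length; i++) copy[i] /= norm;
    System.out.println(java.util.Arrays.toString(copy)); // [0.6, 0.8]
  }
}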

View File

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.io.IOException;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.ScalarQuantizedVectorSimilarity;
import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.hnsw.RandomVectorScorer;
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
/** Quantized vector scorer supplier */
final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
private final RandomAccessQuantizedByteVectorValues values;
private final ScalarQuantizedVectorSimilarity similarity;
ScalarQuantizedRandomVectorScorerSupplier(
VectorSimilarityFunction similarityFunction,
ScalarQuantizer scalarQuantizer,
RandomAccessQuantizedByteVectorValues values) {
this.similarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
similarityFunction, scalarQuantizer.getConstantMultiplier());
this.values = values;
}
@Override
public RandomVectorScorer scorer(int ord) throws IOException {
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
final byte[] queryVector = values.vectorValue(ord);
final float queryOffset = values.getScoreCorrectionConstant();
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
}
}

View File

@ -180,7 +180,7 @@
 * of files, recording dimensionally indexed fields, to enable fast numeric range filtering
 * and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
 * intersection (2D, 3D).
- * <li>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}. The
+ * <li>{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The
 * vector format stores numeric vectors in a format optimized for random access and
 * computation, supporting high-dimensional nearest-neighbor search.
 * </ul>
@ -310,10 +310,11 @@
 * <td>Holds indexed points</td>
 * </tr>
 * <tr>
- * <td>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
+ * <td>{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}</td>
- * <td>.vec, .vem</td>
+ * <td>.vec, .vem, .veq, .vex</td>
- * <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data, and
- * <code>.vem</code> the vector metadata</td>
+ * <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data,
+ * <code>.vem</code> the vector metadata, <code>.veq</code> the quantized vector data, and
+ * <code>.vex</code> the HNSW graph data.</td>
 * </tr>
 * </table>
 *
@ -408,6 +409,8 @@
 * <li>In version 9.5, HNSW graph connections were changed to be delta-encoded with vints.
 * Additionally, metadata file size improvements were made by delta-encoding nodes by graph
 * layer and not writing the node ids for the zeroth layer.
+ * <li>In version 9.9, vector scalar quantization support was added, allowing the HNSW vector
+ * format to use int8-quantized vectors for float32 vector search.
 * </ul>
 *
 * <a id="Limitations"></a>

View File

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* Calculates and adjusts scores correctly for quantized vectors, given the scalar quantization
* parameters
*/
public interface ScalarQuantizedVectorSimilarity {
/**
* Creates a {@link ScalarQuantizedVectorSimilarity} from a {@link VectorSimilarityFunction} and
* the constant multiplier used for quantization.
*
* @param sim similarity function
* @param constMultiplier constant multiplier used for quantization
* @return a {@link ScalarQuantizedVectorSimilarity} that applies the appropriate corrections
*/
static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
VectorSimilarityFunction sim, float constMultiplier) {
return switch (sim) {
case EUCLIDEAN -> new Euclidean(constMultiplier);
case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier);
case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier);
};
}
float score(byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset);
/** Calculates euclidean distance on quantized vectors, applying the appropriate corrections */
class Euclidean implements ScalarQuantizedVectorSimilarity {
private final float constMultiplier;
public Euclidean(float constMultiplier) {
this.constMultiplier = constMultiplier;
}
@Override
public float score(
byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset) {
int squareDistance = VectorUtil.squareDistance(storedVector, queryVector);
float adjustedDistance = squareDistance * constMultiplier;
return 1 / (1f + adjustedDistance);
}
}
/** Calculates dot product on quantized vectors, applying the appropriate corrections */
class DotProduct implements ScalarQuantizedVectorSimilarity {
private final float constMultiplier;
public DotProduct(float constMultiplier) {
this.constMultiplier = constMultiplier;
}
@Override
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return (1 + adjustedDistance) / 2;
}
}
/** Calculates max inner product on quantized vectors, applying the appropriate corrections */
class MaximumInnerProduct implements ScalarQuantizedVectorSimilarity {
private final float constMultiplier;
public MaximumInnerProduct(float constMultiplier) {
this.constMultiplier = constMultiplier;
}
@Override
public float score(
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
return scaleMaxInnerProductScore(adjustedDistance);
}
}
}
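The DotProduct case shows why the per-vector offsets exist: the raw integer dot product is rescaled by the constant multiplier (alpha squared) and then the query-side and stored-side corrective offsets are added back before the result is normalized into a score, exactly as in DotProduct#score above. A standalone sketch with made-up inputs (the offsets and alpha below are hypothetical values, not derived from real quantiles):

// Illustrative only: mirrors DotProduct#score above with plain arrays.
class DotProductScoreSketch {
  public static void main(String[] args) {
    byte[] query = {10, 20, 30};
    float queryOffset = 0.12f;   // hypothetical corrective offset for the query
    byte[] stored = {11, 19, 28};
    float storedOffset = 0.08f;  // hypothetical corrective offset for the stored vector
    float alpha = 0.0157f;       // (maxQuantile - minQuantile) / 127, made up here
    int dot = 0;
    for (int i = 0; i < query.length; i++) {
      dot += query[i] * stored[i];
    }
    // constMultiplier = alpha^2; offsets restore the minQuantile cross terms
    float adjusted = dot * (alpha * alpha) + queryOffset + storedOffset;
    System.out.println((1 + adjusted) / 2); // same normalization as DotProduct#score
  }
}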

View File

@ -0,0 +1,317 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import java.util.stream.IntStream;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
/**
* Scalar quantizes float vectors into `int8` byte values. This is a lossy transformation. Scalar
* quantization works by first calculating the quantiles of the float vector values, using the
* configured quantile/confidence interval. The resulting [minQuantile, maxQuantile] range is then
* used to scale the values into the range [0, 127], rounding each to the nearest byte value.
*
* <h2>How Scalar Quantization Works</h2>
*
* <p>The basic mathematical equations behind this are fairly straightforward. Given a float
* vector `v` and a quantile `q`, we first calculate the quantiles of the vector values,
* [minQuantile, maxQuantile]. Each float value then maps to a byte, and back, as follows:
*
* <pre class="prettyprint">
* byte = (float - minQuantile) * 127/(maxQuantile - minQuantile)
* float = (maxQuantile - minQuantile)/127 * byte + minQuantile
* </pre>
*
* <p>This then means to multiply two float values together (e.g. dot_product) we can do the
* following:
*
* <pre class="prettyprint">
* float1 * float2 ~= (byte1 * (maxQuantile - minQuantile)/127 + minQuantile) * (byte2 * (maxQuantile - minQuantile)/127 + minQuantile)
* float1 * float2 ~= (byte1 * byte2 * (maxQuantile - minQuantile)^2)/(127^2) + (byte1 * minQuantile * (maxQuantile - minQuantile)/127) + (byte2 * minQuantile * (maxQuantile - minQuantile)/127) + minQuantile^2
* let alpha = (maxQuantile - minQuantile)/127
* float1 * float2 ~= (byte1 * byte2 * alpha^2) + (byte1 * minQuantile * alpha) + (byte2 * minQuantile * alpha) + minQuantile^2
* </pre>
*
* <p>The expansion for square distance is much simpler:
*
* <pre class="prettyprint">
* square_distance = (float1 - float2)^2
* (float1 - float2)^2 ~= (byte1 * alpha + minQuantile - byte2 * alpha - minQuantile)^2
* = (alpha*byte1 + minQuantile)^2 + (alpha*byte2 + minQuantile)^2 - 2*(alpha*byte1 + minQuantile)(alpha*byte2 + minQuantile)
* this can be simplified to:
* = alpha^2 (byte1 - byte2)^2
* </pre>
*/
public class ScalarQuantizer {
public static final int SCALAR_QUANTIZATION_SAMPLE_SIZE = 25_000;
private final float alpha;
private final float scale;
private final float minQuantile, maxQuantile, configuredQuantile;
/**
* @param minQuantile the lower quantile of the distribution
* @param maxQuantile the upper quantile of the distribution
* @param configuredQuantile The configured quantile/confidence interval used to calculate the
* quantiles.
*/
public ScalarQuantizer(float minQuantile, float maxQuantile, float configuredQuantile) {
assert maxQuantile >= minQuantile;
this.minQuantile = minQuantile;
this.maxQuantile = maxQuantile;
this.scale = 127f / (maxQuantile - minQuantile);
this.alpha = (maxQuantile - minQuantile) / 127f;
this.configuredQuantile = configuredQuantile;
}
/**
* Quantize a float vector into a byte vector
*
* @param src the source vector
* @param dest the destination vector
* @param similarityFunction the similarity function used to calculate the quantile
* @return the corrective offset that needs to be applied to the score
*/
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
assert src.length == dest.length;
float correctiveOffset = 0f;
for (int i = 0; i < src.length; i++) {
float v = src[i];
// Make sure the value is within the quantile range, cutting off the tails
// see first parenthesis in equation:
// byte = (float - minQuantile) * 127/(maxQuantile - minQuantile)
float dx = Math.max(minQuantile, Math.min(maxQuantile, src[i])) - minQuantile;
// Scale the value to the range [0, 127], this is our quantized value
// scale = 127/(maxQuantile - minQuantile)
float dxs = scale * dx;
// We multiply by `alpha` here to get the quantized value back into the original range
// to aid in calculating the corrective offset
float dxq = Math.round(dxs) * alpha;
// Calculate the corrective offset that needs to be applied to the score
// in addition to the `byte * minQuantile * alpha` term in the equation
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
// will be rounded to the nearest whole number and lose some accuracy
// Additionally, we account for the global correction of `minQuantile^2` in the equation
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
dest[i] = (byte) Math.round(dxs);
}
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0;
}
return correctiveOffset;
}
/**
* Recalculate the old score corrective value given new current quantiles
*
* @param quantizedVector the old vector
* @param oldQuantizer the old quantizer
* @param similarityFunction the similarity function used to calculate the quantile
* @return the new offset
*/
public float recalculateCorrectiveOffset(
byte[] quantizedVector,
ScalarQuantizer oldQuantizer,
VectorSimilarityFunction similarityFunction) {
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
return 0f;
}
float correctiveOffset = 0f;
for (int i = 0; i < quantizedVector.length; i++) {
// dequantize the old value in order to recalculate the corrective offset
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
float dxs = scale * dx;
float dxq = Math.round(dxs) * alpha;
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
}
return correctiveOffset;
}
/**
* Dequantize a byte vector into a float vector
*
* @param src the source vector
* @param dest the destination vector
*/
public void deQuantize(byte[] src, float[] dest) {
assert src.length == dest.length;
for (int i = 0; i < src.length; i++) {
dest[i] = (alpha * src[i]) + minQuantile;
}
}
public float getLowerQuantile() {
return minQuantile;
}
public float getUpperQuantile() {
return maxQuantile;
}
public float getConfiguredQuantile() {
return configuredQuantile;
}
public float getConstantMultiplier() {
return alpha * alpha;
}
@Override
public String toString() {
return "ScalarQuantizer{"
+ "minQuantile="
+ minQuantile
+ ", maxQuantile="
+ maxQuantile
+ ", configuredQuantile="
+ configuredQuantile
+ '}';
}
private static final Random random = new Random(42);
/**
* Reads the float vector values and calculates the quantiles. If the number of float vectors is
* less than {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE}, all the values are read and the quantiles
* calculated; otherwise the quantiles are calculated from a reservoir sample of {@link
* #SCALAR_QUANTIZATION_SAMPLE_SIZE} vectors.
*
* @param floatVectorValues the float vector values from which to calculate the quantiles
* @param quantile the quantile/confidence interval used to calculate the quantiles
* @return A new {@link ScalarQuantizer} instance
* @throws IOException if there is an error reading the float vector values
*/
public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float quantile)
throws IOException {
assert 0.9f <= quantile && quantile <= 1f;
if (floatVectorValues.size() == 0) {
return new ScalarQuantizer(0f, 0f, quantile);
}
if (quantile == 1f) {
float min = Float.POSITIVE_INFINITY;
float max = Float.NEGATIVE_INFINITY;
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
for (float v : floatVectorValues.vectorValue()) {
min = Math.min(min, v);
max = Math.max(max, v);
}
}
return new ScalarQuantizer(min, max, quantile);
}
int dim = floatVectorValues.dimension();
if (floatVectorValues.size() < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
int copyOffset = 0;
float[] values = new float[floatVectorValues.size() * dim];
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
float[] floatVector = floatVectorValues.vectorValue();
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
copyOffset += dim;
}
float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
}
int numFloatVecs = floatVectorValues.size();
// Reservoir sample the vector ordinals we want to read
float[] values = new float[SCALAR_QUANTIZATION_SAMPLE_SIZE * dim];
int[] vectorsToTake = IntStream.range(0, SCALAR_QUANTIZATION_SAMPLE_SIZE).toArray();
for (int i = SCALAR_QUANTIZATION_SAMPLE_SIZE; i < numFloatVecs; i++) {
int j = random.nextInt(i + 1);
if (j < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
vectorsToTake[j] = i;
}
}
Arrays.sort(vectorsToTake);
int copyOffset = 0;
int index = 0;
for (int i : vectorsToTake) {
while (index <= i) {
// We cannot use `advance(docId)` as MergedVectorValues does not support it
floatVectorValues.nextDoc();
index++;
}
assert floatVectorValues.docID() != NO_MORE_DOCS;
float[] floatVector = floatVectorValues.vectorValue();
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
copyOffset += dim;
}
float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
}
/**
* Takes an array of floats, sorted or not, and returns a minimum and maximum value. These values
* reside at the `(1 - quantile)/2` and `(1 + quantile)/2` percentiles of the distribution.
* Example: providing floats `[0..100]` and asking for the `0.9` quantile will return `5` and `95`.
*
* @param arr array of floats
* @param quantileFloat the configured quantile
* @return lower and upper quantile values
*/
static float[] getUpperAndLowerQuantile(float[] arr, float quantileFloat) {
assert 0.9f <= quantileFloat && quantileFloat <= 1f;
int selectorIndex = (int) (arr.length * (1f - quantileFloat) / 2f + 0.5f);
if (selectorIndex > 0) {
Selector selector = new FloatSelector(arr);
selector.select(0, arr.length, arr.length - selectorIndex);
selector.select(0, arr.length - selectorIndex, selectorIndex);
}
float min = Float.POSITIVE_INFINITY;
float max = Float.NEGATIVE_INFINITY;
for (int i = selectorIndex; i < arr.length - selectorIndex; i++) {
min = Math.min(arr[i], min);
max = Math.max(arr[i], max);
}
return new float[] {min, max};
}
private static class FloatSelector extends IntroSelector {
float pivot = Float.NaN;
private final float[] arr;
private FloatSelector(float[] arr) {
this.arr = arr;
}
@Override
protected void setPivot(int i) {
pivot = arr[i];
}
@Override
protected int comparePivot(int j) {
return Float.compare(pivot, arr[j]);
}
@Override
protected void swap(int i, int j) {
final float tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
}
}
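A worked pass through the equations in the class javadoc makes the corrective offset concrete. The sketch below round-trips a single value with hypothetical quantiles [-1, 1], computing the quantized byte, the dequantized value, and the (dx - dxq) * dxq rounding term that quantize folds into the corrective offset:

// Illustrative only: round-trips one value through the quantization equations
// from ScalarQuantizer's javadoc, with made-up quantiles [-1, 1].
class RoundTripSketch {
  public static void main(String[] args) {
    float minQuantile = -1f, maxQuantile = 1f;
    float scale = 127f / (maxQuantile - minQuantile); // 63.5
    float alpha = (maxQuantile - minQuantile) / 127f; // ~0.01575
    float v = 0.30f;
    // clamp to the quantile range, then shift by minQuantile
    float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile; // 1.30
    byte quantized = (byte) Math.round(scale * dx); // round(82.55) = 83
    float dequantized = alpha * quantized + minQuantile; // ~0.3071, the lossy round trip
    float dxq = Math.round(scale * dx) * alpha; // quantized value mapped back, ~1.3071
    System.out.println("byte=" + quantized + " back=" + dequantized
        + " roundingTerm=" + (dx - dxq) * dxq);
  }
}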

View File

@ -0,0 +1,27 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util.hnsw;
import java.io.Closeable;
/**
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
* close after use
*/
public interface CloseableRandomVectorScorerSupplier
extends Closeable, RandomVectorScorerSupplier {}

View File

@ -13,4 +13,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat
+org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat

View File

@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.NoMergePolicy;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.util.ScalarQuantizer;
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return new Lucene99Codec() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene99HnswVectorsFormat(
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
new Lucene99ScalarQuantizedVectorsFormat());
}
};
}
public void testQuantizedVectorsWriteAndRead() throws Exception {
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
VectorSimilarityFunction similarityFunction = randomSimilarity();
int dim = random().nextInt(64) + 1;
List<float[]> vectors = new ArrayList<>(numVectors);
for (int i = 0; i < numVectors; i++) {
vectors.add(randomVector(dim));
}
float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile);
float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) {
expectedVectors[i] = new byte[dim];
expectedCorrections[i] =
scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction);
}
float[] randomlyReusedVector = new float[dim];
try (Directory dir = newDirectory();
IndexWriter w =
new IndexWriter(
dir,
new IndexWriterConfig()
.setMaxBufferedDocs(numVectors + 1)
.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH)
.setMergePolicy(NoMergePolicy.INSTANCE))) {
for (int i = 0; i < numVectors; i++) {
Document doc = new Document();
// randomly reuse a vector; this ensures the underlying codec doesn't rely on the array
// reference
final float[] v;
if (random().nextBoolean()) {
System.arraycopy(vectors.get(i), 0, randomlyReusedVector, 0, dim);
v = randomlyReusedVector;
} else {
v = vectors.get(i);
}
doc.add(new KnnFloatVectorField("f", v, similarityFunction));
w.addDocument(doc);
}
w.commit();
try (IndexReader reader = DirectoryReader.open(w)) {
LeafReader r = getOnlyLeafReader(reader);
if (r instanceof CodecReader codecReader) {
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
knnVectorsReader = fieldsReader.getFieldReader("f");
}
if (knnVectorsReader instanceof Lucene99HnswVectorsReader hnswReader) {
assertNotNull(hnswReader.getQuantizationState("f"));
QuantizedByteVectorValues quantizedByteVectorValues =
hnswReader.getQuantizedVectorValues("f");
int docId = -1;
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
byte[] vector = quantizedByteVectorValues.vectorValue();
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
for (int i = 0; i < dim; i++) {
assertEquals(vector[i], expectedVectors[docId][i]);
}
assertEquals(offset, expectedCorrections[docId], 0.00001f);
}
} else {
fail("reader is not Lucene99HnswVectorsReader");
}
} else {
fail("reader is not CodecReader");
}
}
}
}
public void testToString() {
FilterCodec customCodec =
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new Lucene99HnswVectorsFormat(
10, 20, new Lucene99ScalarQuantizedVectorsFormat(0.9f));
}
};
String expectedString =
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9))";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}
}

View File

@ -14,7 +14,7 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-package org.apache.lucene.codecs.lucene95;
+package org.apache.lucene.codecs.lucene99;
import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
@ -22,7 +22,7 @ import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.tests.util.TestUtil;
-public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
+public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override
protected Codec getCodec() {
return TestUtil.getDefaultCodec();
@ -33,20 +33,22 @ public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
new FilterCodec("foo", Codec.getDefault()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
-return new Lucene95HnswVectorsFormat(10, 20);
+return new Lucene99HnswVectorsFormat(10, 20, null);
}
};
String expectedString =
-"Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn=10, beamWidth=20)";
+"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=none)";
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
}
public void testLimits() {
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(-1, 20));
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(0, 20));
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, 0));
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, -1));
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(512 + 1, 20));
-expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, 3201));
+expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20, null));
+expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20, null));
+expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0, null));
+expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1, null));
+expectThrows(
+IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20, null));
+expectThrows(
+IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201, null));
}
}

View File

@ -0,0 +1,79 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.codecs.lucene99;
import java.util.ArrayList;
import java.util.List;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorsFormat extends LuceneTestCase {
public void testDefaultQuantile() {
float defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(99);
assertEquals(0.99f, defaultQuantile, 1e-5);
defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(1);
assertEquals(0.9f, defaultQuantile, 1e-5);
defaultQuantile =
Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(Integer.MAX_VALUE - 2);
assertEquals(1.0f, defaultQuantile, 1e-5);
}
public void testLimits() {
expectThrows(
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(0.89f));
expectThrows(
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(1.1f));
}
public void testQuantileMergeWithMissing() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
quantiles.add(null);
List<Integer> segmentSizes = List.of(1, 1, 1, 1);
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f));
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(List.of(), List.of(), 0.1f));
}
public void testQuantileMerge() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
List<Integer> segmentSizes = List.of(1, 1, 1);
ScalarQuantizer merged =
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
assertEquals(0.2f, merged.getLowerQuantile(), 1e-5);
assertEquals(0.3f, merged.getUpperQuantile(), 1e-5);
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
}
public void testQuantileMergeWithDifferentSegmentSizes() {
List<ScalarQuantizer> quantiles = new ArrayList<>();
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
List<Integer> segmentSizes = List.of(1, 2, 3);
ScalarQuantizer merged =
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
assertEquals(0.2333333f, merged.getLowerQuantile(), 1e-5);
assertEquals(0.3333333f, merged.getUpperQuantile(), 1e-5);
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
}
}
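The expectations in testQuantileMerge and testQuantileMergeWithDifferentSegmentSizes are consistent with mergeQuantiles computing a segment-size-weighted mean of the per-segment quantiles; the writer's implementation is not shown in this excerpt, so treat that reading as an inference from the tests. A sketch reproducing the expected lower quantile:

// Illustrative only: reproduces the expected lower quantile from the test above
// as a segment-size-weighted mean: (0.1*1 + 0.2*2 + 0.3*3) / (1 + 2 + 3) = 0.2333...
class WeightedQuantileSketch {
  public static void main(String[] args) {
    float[] lowerQuantiles = {0.1f, 0.2f, 0.3f};
    int[] segmentSizes = {1, 2, 3};
    float weighted = 0f;
    int totalSize = 0;
    for (int i = 0; i < lowerQuantiles.length; i++) {
      weighted += lowerQuantiles[i] * segmentSizes[i];
      totalSize += segmentSizes[i];
    }
    System.out.println(weighted / totalSize); // 0.23333335
  }
}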

View File

@ -29,7 +29,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.KnnVectorsWriter;
-import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
+import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnFloatVectorField;
@ -170,8 +170,8 @ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
try (Directory directory = newDirectory()) {
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
KnnVectorsFormat format1 =
-new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
+new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100, null));
-KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100);
+KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100, null);
iwc.setCodec(
new AssertingCodec() {
@Override

View File

@ -16,6 +16,7 @@
*/ */
package org.apache.lucene.index; package org.apache.lucene.index;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween; import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed; import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;
@ -31,8 +32,9 @@ import java.util.concurrent.CountDownLatch;
import org.apache.lucene.codecs.Codec; import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
@ -78,23 +80,15 @@ public class TestKnnGraph extends LuceneTestCase {
M = random().nextInt(256) + 3; M = random().nextInt(256) + 3;
} }
codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override
public KnnVectorsFormat knnVectorsFormat() {
return new PerFieldKnnVectorsFormat() {
@Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
}
};
}
};
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1; int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
similarityFunction = VectorSimilarityFunction.values()[similarity]; similarityFunction = VectorSimilarityFunction.values()[similarity];
vectorEncoding = randomVectorEncoding(); vectorEncoding = randomVectorEncoding();
Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat =
vectorEncoding.equals(VectorEncoding.FLOAT32) && randomBoolean()
? new Lucene99ScalarQuantizedVectorsFormat(1f)
: null;
codec = codec =
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) { new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
@Override @Override
@ -102,13 +96,14 @@ public class TestKnnGraph extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); return new Lucene99HnswVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, scalarQuantizedVectorsFormat);
} }
}; };
} }
}; };
if (vectorEncoding == VectorEncoding.FLOAT32) { if (vectorEncoding == VectorEncoding.FLOAT32 && scalarQuantizedVectorsFormat == null) {
float32Codec = codec; float32Codec = codec;
} else { } else {
float32Codec = float32Codec =
@ -118,7 +113,8 @@ public class TestKnnGraph extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH); return new Lucene99HnswVectorsFormat(
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, null);
} }
}; };
} }
@ -409,8 +405,8 @@ public class TestKnnGraph extends LuceneTestCase {
if (perFieldReader == null) { if (perFieldReader == null) {
continue; continue;
} }
Lucene95HnswVectorsReader vectorReader = Lucene99HnswVectorsReader vectorReader =
(Lucene95HnswVectorsReader) perFieldReader.getFieldReader(vectorField); (Lucene99HnswVectorsReader) perFieldReader.getFieldReader(vectorField);
if (vectorReader == null) { if (vectorReader == null) {
continue; continue;
} }
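Editor's note on the hunks above, since the gating is easy to miss in the diff: quantization is opt-in through the new third constructor argument, it is only attempted for FLOAT32 encodings, and passing null keeps the existing raw float32 path. A minimal sketch of that selection logic (the class and method names here are illustrative, not part of the commit):

import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.index.VectorEncoding;

final class FormatChoiceSketch {
  // Sketch only: mirrors the test's randomized choice. Quantization applies to
  // FLOAT32 vectors; a null quantized format means "store raw float32".
  static KnnVectorsFormat choose(VectorEncoding encoding, boolean quantize, int m, int beamWidth) {
    Lucene99ScalarQuantizedVectorsFormat quantized =
        encoding == VectorEncoding.FLOAT32 && quantize
            ? new Lucene99ScalarQuantizedVectorsFormat(1f) // quantile argument, as in the hunk
            : null;
    return new Lucene99HnswVectorsFormat(m, beamWidth, quantized);
  }
}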
View File
@ -0,0 +1,216 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import static org.apache.lucene.util.TestScalarQuantizer.fromFloats;
import static org.apache.lucene.util.TestScalarQuantizer.randomFloatArray;
import static org.apache.lucene.util.TestScalarQuantizer.randomFloats;
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
public void testToEuclidean() throws IOException {
int dims = 128;
int numVecs = 100;
float[][] floats = randomFloats(numVecs, dims);
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.EUCLIDEAN, scalarQuantizer.getConstantMultiplier());
assertQuantizedScores(
floats,
quantized,
offsets,
query,
error,
VectorSimilarityFunction.EUCLIDEAN,
quantizedSimilarity,
scalarQuantizer);
}
}
public void testToCosine() throws IOException {
int dims = 128;
int numVecs = 100;
float[][] floats = randomFloats(numVecs, dims);
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectorsNormalized(
scalarQuantizer, floats, quantized, VectorSimilarityFunction.COSINE);
float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
VectorUtil.l2normalize(query);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.COSINE, scalarQuantizer.getConstantMultiplier());
assertQuantizedScores(
floats,
quantized,
offsets,
query,
error,
VectorSimilarityFunction.COSINE,
quantizedSimilarity,
scalarQuantizer);
}
}
public void testToDotProduct() throws IOException {
int dims = 128;
int numVecs = 100;
float[][] floats = randomFloats(numVecs, dims);
for (float[] fs : floats) {
VectorUtil.l2normalize(fs);
}
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
float[] query = randomFloatArray(dims);
VectorUtil.l2normalize(query);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.DOT_PRODUCT, scalarQuantizer.getConstantMultiplier());
assertQuantizedScores(
floats,
quantized,
offsets,
query,
error,
VectorSimilarityFunction.DOT_PRODUCT,
quantizedSimilarity,
scalarQuantizer);
}
}
public void testToMaxInnerProduct() throws IOException {
int dims = 128;
int numVecs = 100;
float[][] floats = randomFloats(numVecs, dims);
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
float error = Math.max((100 - quantile) * 0.5f, 0.5f);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
byte[][] quantized = new byte[floats.length][];
float[] offsets =
quantizeVectors(
scalarQuantizer, floats, quantized, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
float[] query = randomFloatArray(dims);
ScalarQuantizedVectorSimilarity quantizedSimilarity =
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
scalarQuantizer.getConstantMultiplier());
assertQuantizedScores(
floats,
quantized,
offsets,
query,
error,
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
quantizedSimilarity,
scalarQuantizer);
}
}
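  /** Verifies each quantized score tracks the full-precision score within {@code error}. */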
private void assertQuantizedScores(
float[][] floats,
byte[][] quantized,
float[] storedOffsets,
float[] query,
float error,
VectorSimilarityFunction similarityFunction,
ScalarQuantizedVectorSimilarity quantizedSimilarity,
ScalarQuantizer scalarQuantizer) {
for (int i = 0; i < floats.length; i++) {
float storedOffset = storedOffsets[i];
byte[] quantizedQuery = new byte[query.length];
float queryOffset = scalarQuantizer.quantize(query, quantizedQuery, similarityFunction);
float original = similarityFunction.compare(query, floats[i]);
float quantizedScore =
quantizedSimilarity.score(quantizedQuery, queryOffset, quantized[i], storedOffset);
assertEquals("Not within acceptable error [" + error + "]", original, quantizedScore, error);
}
}
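  /** Quantizes each vector into {@code quantized} and returns the per-vector corrective offsets. */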
private static float[] quantizeVectors(
ScalarQuantizer scalarQuantizer,
float[][] floats,
byte[][] quantized,
VectorSimilarityFunction similarityFunction) {
int i = 0;
float[] offsets = new float[floats.length];
for (float[] v : floats) {
quantized[i] = new byte[v.length];
offsets[i] = scalarQuantizer.quantize(v, quantized[i], similarityFunction);
++i;
}
return offsets;
}
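  /** Like {@link #quantizeVectors}, but l2-normalizes a copy of each vector before quantizing. */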
private static float[] quantizeVectorsNormalized(
ScalarQuantizer scalarQuantizer,
float[][] floats,
byte[][] quantized,
VectorSimilarityFunction similarityFunction) {
int i = 0;
float[] offsets = new float[floats.length];
for (float[] f : floats) {
float[] v = ArrayUtil.copyOfSubArray(f, 0, f.length);
VectorUtil.l2normalize(v);
quantized[i] = new byte[v.length];
offsets[i] = scalarQuantizer.quantize(v, quantized[i], similarityFunction);
++i;
}
return offsets;
}
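  /** Presents the raw vectors l2-normalized on read, matching what the cosine test quantizes. */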
private static FloatVectorValues fromFloatsNormalized(float[][] floats) {
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats) {
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= floats.length) {
throw new IOException("Current doc not set or too many iterations");
}
float[] v = ArrayUtil.copyOfSubArray(floats[curDoc], 0, floats[curDoc].length);
VectorUtil.l2normalize(v);
return v;
}
};
}
}
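Every test above follows the same shape: quantize the corpus and the query, then check that ScalarQuantizedVectorSimilarity.score(quantizedQuery, queryOffset, quantizedDoc, storedOffset) stays within error of the full-precision score. For the dot-product case, the single stored float per vector falls out of simple algebra. A back-of-the-envelope sketch assuming a plain linear int8 mapping (illustrative only, not Lucene's exact corrective formula, which also compensates for rounding error):

// Sketch only: with x ~= a + s*qx componentwise (a = lower bound, s = scale),
// dot(x, y) expands to s^2 * dot(qx, qy) + offset(x) + offset(y), where
// offset(v) = a * s * sum(q_v) + dims * a^2 / 2. One float per vector suffices.
final class OffsetDotProductSketch {
  final float a, s; // assumes upper > lower

  OffsetDotProductSketch(float lower, float upper) {
    this.a = lower;
    this.s = (upper - lower) / 127f;
  }

  /** Quantizes v into dest and returns its corrective offset. */
  float quantize(float[] v, byte[] dest) {
    float sumQ = 0;
    for (int i = 0; i < v.length; i++) {
      float clamped = Math.min(Math.max(v[i], a), a + s * 127f);
      dest[i] = (byte) Math.round((clamped - a) / s);
      sumQ += dest[i];
    }
    return a * s * sumQ + v.length * a * a / 2f;
  }

  /** Approximates the float dot product from the quantized forms and offsets alone. */
  float score(byte[] qx, float offX, byte[] qy, float offY) {
    int dot = 0;
    for (int i = 0; i < qx.length; i++) {
      dot += qx[i] * qy[i];
    }
    return s * s * dot + offX + offY;
  }
}

Expanding dot(a + s*qx, a + s*qy) gives dims*a^2 + a*s*(sum qx + sum qy) + s^2*dot(qx, qy); splitting the constant term evenly between the two offsets yields exactly the four-argument score(...) shape these tests exercise.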
View File
@ -0,0 +1,149 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.util;
import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestScalarQuantizer extends LuceneTestCase {
public void testQuantizeAndDeQuantize() throws IOException {
int dims = 128;
int numVecs = 100;
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
float[][] floats = randomFloats(numVecs, dims);
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1);
float[] dequantized = new float[dims];
byte[] quantized = new byte[dims];
byte[] requantized = new byte[dims];
for (int i = 0; i < numVecs; i++) {
scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
scalarQuantizer.deQuantize(quantized, dequantized);
scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
for (int j = 0; j < dims; j++) {
assertEquals(dequantized[j], floats[i][j], 0.02);
assertEquals(quantized[j], requantized[j]);
}
}
}
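  // With values 0..999, a confidence level c keeps the middle c mass of the
  // distribution: (1 - c) / 2 * 1000 entries come off each tail, so c = 0.9
  // yields bounds [50, 949], c = 0.95 yields [25, 974], and c = 0.99 yields [5, 994].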
public void testQuantiles() {
float[] percs = new float[1000];
for (int i = 0; i < 1000; i++) {
percs[i] = (float) i;
}
shuffleArray(percs);
float[] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.9f);
assertEquals(50f, upperAndLower[0], 1e-7);
assertEquals(949f, upperAndLower[1], 1e-7);
shuffleArray(percs);
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.95f);
assertEquals(25f, upperAndLower[0], 1e-7);
assertEquals(974f, upperAndLower[1], 1e-7);
shuffleArray(percs);
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.99f);
assertEquals(5f, upperAndLower[0], 1e-7);
assertEquals(994f, upperAndLower[1], 1e-7);
}
public void testEdgeCase() {
float[] upperAndLower =
ScalarQuantizer.getUpperAndLowerQuantile(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 0.9f);
assertEquals(1f, upperAndLower[0], 1e-7f);
assertEquals(1f, upperAndLower[1], 1e-7f);
}
static void shuffleArray(float[] ar) {
for (int i = ar.length - 1; i > 0; i--) {
int index = random().nextInt(i + 1);
float a = ar[index];
ar[index] = ar[i];
ar[i] = a;
}
}
static float[] randomFloatArray(int dims) {
float[] arr = new float[dims];
for (int j = 0; j < dims; j++) {
arr[j] = random().nextFloat(-1, 1);
}
return arr;
}
static float[][] randomFloats(int num, int dims) {
float[][] floats = new float[num][];
for (int i = 0; i < num; i++) {
floats[i] = randomFloatArray(dims);
}
return floats;
}
static FloatVectorValues fromFloats(float[][] floats) {
return new TestSimpleFloatVectorValues(floats);
}
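  /** Minimal in-memory FloatVectorValues over a float[][]; doc IDs are the array indices. */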
static class TestSimpleFloatVectorValues extends FloatVectorValues {
protected final float[][] floats;
protected int curDoc = -1;
TestSimpleFloatVectorValues(float[][] values) {
this.floats = values;
}
@Override
public int dimension() {
return floats[0].length;
}
@Override
public int size() {
return floats.length;
}
@Override
public float[] vectorValue() throws IOException {
if (curDoc == -1 || curDoc >= floats.length) {
throw new IOException("Current doc not set or too many iterations");
}
return floats[curDoc];
}
@Override
public int docID() {
if (curDoc >= floats.length) {
return NO_MORE_DOCS;
}
return curDoc;
}
@Override
public int nextDoc() throws IOException {
curDoc++;
return docID();
}
@Override
public int advance(int target) throws IOException {
curDoc = target;
return docID();
}
}
}
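For reference, the bounds asserted in testQuantiles can be reproduced with a plain sort-and-trim. A hypothetical stand-in for getUpperAndLowerQuantile (the shipped implementation need not sort the whole array):

import java.util.Arrays;

final class QuantileSketch {
  // Sketch only: trims (1 - confidence) / 2 of the values from each tail of the
  // sorted array and returns the surviving extremes as {lower, upper}.
  static float[] upperAndLowerQuantile(float[] values, float confidence) {
    float[] sorted = values.clone();
    Arrays.sort(sorted);
    int trim = Math.round((1f - confidence) / 2f * sorted.length);
    return new float[] {sorted[trim], sorted[sorted.length - 1 - trim]};
  }
}

On the shuffled 0..999 inputs this reproduces [50, 949], [25, 974], and [5, 994] for confidences 0.9, 0.95, and 0.99, and [1, 1] for the constant-vector edge case.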
View File
@ -40,8 +40,8 @@ import java.util.concurrent.TimeoutException;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
@ -165,7 +165,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, beamWidth); return new Lucene99HnswVectorsFormat(M, beamWidth, null);
} }
}; };
} }
@ -237,7 +237,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, beamWidth); return new Lucene99HnswVectorsFormat(M, beamWidth, null);
} }
}; };
} }
@ -266,7 +266,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
assertEquals(indexedDoc, ctx.reader().numDocs()); assertEquals(indexedDoc, ctx.reader().numDocs());
assertVectorsEqual(v3, values); assertVectorsEqual(v3, values);
HnswGraph graphValues = HnswGraph graphValues =
((Lucene95HnswVectorsReader) ((Lucene99HnswVectorsReader)
((PerFieldKnnVectorsFormat.FieldsReader) ((PerFieldKnnVectorsFormat.FieldsReader)
((CodecReader) ctx.reader()).getVectorReader()) ((CodecReader) ctx.reader()).getVectorReader())
.getFieldReader("field")) .getFieldReader("field"))
@ -298,7 +298,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, beamWidth); return new Lucene99HnswVectorsFormat(M, beamWidth, null);
} }
}; };
} }
@ -312,7 +312,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
return new PerFieldKnnVectorsFormat() { return new PerFieldKnnVectorsFormat() {
@Override @Override
public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
return new Lucene95HnswVectorsFormat(M, beamWidth); return new Lucene99HnswVectorsFormat(M, beamWidth, null);
} }
}; };
} }
View File
@ -18,7 +18,7 @@ package org.apache.lucene.tests.codecs.vector;
import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat; import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
/** /**
@ -32,12 +32,12 @@ public class ConfigurableMCodec extends FilterCodec {
public ConfigurableMCodec() { public ConfigurableMCodec() {
super("ConfigurableMCodec", TestUtil.getDefaultCodec()); super("ConfigurableMCodec", TestUtil.getDefaultCodec());
knnVectorsFormat = new Lucene95HnswVectorsFormat(128, 100); knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100, null);
} }
public ConfigurableMCodec(int maxConn) { public ConfigurableMCodec(int maxConn) {
super("ConfigurableMCodec", TestUtil.getDefaultCodec()); super("ConfigurableMCodec", TestUtil.getDefaultCodec());
knnVectorsFormat = new Lucene95HnswVectorsFormat(maxConn, 100); knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100, null);
} }
@Override @Override
View File
@ -712,7 +712,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
} }
} }
private VectorSimilarityFunction randomSimilarity() { protected VectorSimilarityFunction randomSimilarity() {
return VectorSimilarityFunction.values()[ return VectorSimilarityFunction.values()[
random().nextInt(VectorSimilarityFunction.values().length)]; random().nextInt(VectorSimilarityFunction.values().length)];
} }
@ -1221,7 +1221,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
iw.updateDocument(idTerm, doc); iw.updateDocument(idTerm, doc);
} }
private float[] randomVector(int dim) { protected float[] randomVector(int dim) {
assert dim > 0; assert dim > 0;
float[] v = new float[dim]; float[] v = new float[dim];
double squareSum = 0.0; double squareSum = 0.0;
@ -1409,6 +1409,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
} }
} }
assertEquals( assertEquals(
"encoding=" + vectorEncoding,
fieldValuesCheckSum, fieldValuesCheckSum,
checksum, checksum,
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5); vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
View File
@ -55,8 +55,8 @@ import org.apache.lucene.codecs.PostingsFormat;
import org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat; import org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat; import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99Codec; import org.apache.lucene.codecs.lucene99.Lucene99Codec;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat; import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat; import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
import org.apache.lucene.document.BinaryDocValuesField; import org.apache.lucene.document.BinaryDocValuesField;
@ -1322,7 +1322,7 @@ public final class TestUtil {
* Lucene. * Lucene.
*/ */
public static KnnVectorsFormat getDefaultKnnVectorsFormat() { public static KnnVectorsFormat getDefaultKnnVectorsFormat() {
return new Lucene95HnswVectorsFormat(); return new Lucene99HnswVectorsFormat();
} }
public static boolean anyFilesExceptWriteLock(Directory dir) throws IOException { public static boolean anyFilesExceptWriteLock(Directory dir) throws IOException {
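Taken together, the hunks above show the whole opt-in story, so a hypothetical end-to-end wiring sketch for an application may be useful. It mirrors the FilterCodec-plus-PerFieldKnnVectorsFormat pattern the tests use; Codec.getDefault() stands in for TestUtil.getDefaultCodec(), and M = 16, beamWidth = 100, and the 0.9f quantile are arbitrary illustrative choices, not defaults mandated by the commit:

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.codecs.FilterCodec;
import org.apache.lucene.codecs.KnnVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;

public class QuantizedIndexSketch {
  public static void main(String[] args) throws Exception {
    // Wrap the default codec, swapping in the quantized HNSW format per field.
    Codec codec =
        new FilterCodec(Codec.getDefault().getName(), Codec.getDefault()) {
          @Override
          public KnnVectorsFormat knnVectorsFormat() {
            return new PerFieldKnnVectorsFormat() {
              @Override
              public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
                // A non-null third argument enables int8 scalar quantization;
                // null (as in several hunks above) keeps raw float32 vectors.
                return new Lucene99HnswVectorsFormat(
                    16, 100, new Lucene99ScalarQuantizedVectorsFormat(0.9f));
              }
            };
          }
        };
    try (Directory dir = new ByteBuffersDirectory();
        IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setCodec(codec))) {
      Document doc = new Document();
      doc.add(
          new KnnFloatVectorField(
              "vector", new float[] {0.1f, 0.2f, 0.3f}, VectorSimilarityFunction.DOT_PRODUCT));
      writer.addDocument(doc); // vectors are quantized on flush and merge
    }
  }
}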