mirror of https://github.com/apache/lucene.git
Add new int8 scalar quantization to HNSW codec (#12582)
Adds new int8 scalar quantization for HNSW codec. This uses a new lucene9.9 format and auto quantizes floating point vectors into bytes on flush and merge.
This commit is contained in:
parent
e5b55761d0
commit
f2bf5339e5
|
@ -47,7 +47,8 @@ module org.apache.lucene.backward_codecs {
|
|||
org.apache.lucene.backward_codecs.lucene90.Lucene90HnswVectorsFormat,
|
||||
org.apache.lucene.backward_codecs.lucene91.Lucene91HnswVectorsFormat,
|
||||
org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat,
|
||||
org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat;
|
||||
org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat,
|
||||
org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
provides org.apache.lucene.codecs.Codec with
|
||||
org.apache.lucene.backward_codecs.lucene80.Lucene80Codec,
|
||||
org.apache.lucene.backward_codecs.lucene84.Lucene84Codec,
|
||||
|
|
|
@ -40,7 +40,6 @@ import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
|
|||
import org.apache.lucene.codecs.lucene90.Lucene90StoredFieldsFormat;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90TermVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
|
||||
|
@ -145,7 +144,7 @@ public class Lucene95Codec extends Codec {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final SegmentInfoFormat segmentInfoFormat() {
|
||||
public SegmentInfoFormat segmentInfoFormat() {
|
||||
return segmentInfosFormat;
|
||||
}
|
||||
|
||||
|
@ -165,7 +164,7 @@ public class Lucene95Codec extends Codec {
|
|||
}
|
||||
|
||||
@Override
|
||||
public final KnnVectorsFormat knnVectorsFormat() {
|
||||
public KnnVectorsFormat knnVectorsFormat() {
|
||||
return knnVectorsFormat;
|
||||
}
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
|
@ -96,7 +96,7 @@ import org.apache.lucene.util.hnsw.HnswGraph;
|
|||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
||||
public class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
||||
|
||||
static final String META_CODEC_NAME = "Lucene95HnswVectorsFormatMeta";
|
||||
static final String VECTOR_DATA_CODEC_NAME = "Lucene95HnswVectorsFormatData";
|
||||
|
@ -105,8 +105,8 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
|||
static final String VECTOR_DATA_EXTENSION = "vec";
|
||||
static final String VECTOR_INDEX_EXTENSION = "vex";
|
||||
|
||||
public static final int VERSION_START = 0;
|
||||
public static final int VERSION_CURRENT = VERSION_START;
|
||||
static final int VERSION_START = 0;
|
||||
static final int VERSION_CURRENT = 1;
|
||||
|
||||
/**
|
||||
* A maximum configurable maximum max conn.
|
||||
|
@ -137,14 +137,14 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
|||
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
|
||||
* {@link Lucene95HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
|
||||
*/
|
||||
private final int maxConn;
|
||||
final int maxConn;
|
||||
|
||||
/**
|
||||
* The number of candidate neighbors to track while searching the graph for each newly inserted
|
||||
* node. Defaults to to {@link Lucene95HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
|
||||
* HnswGraph} for details.
|
||||
*/
|
||||
private final int beamWidth;
|
||||
final int beamWidth;
|
||||
|
||||
/** Constructs a format using default graph construction parameters */
|
||||
public Lucene95HnswVectorsFormat() {
|
||||
|
@ -179,7 +179,7 @@ public final class Lucene95HnswVectorsFormat extends KnnVectorsFormat {
|
|||
|
||||
@Override
|
||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
return new Lucene95HnswVectorsWriter(state, maxConn, beamWidth);
|
||||
throw new UnsupportedOperationException("Old codecs may only be used for reading");
|
||||
}
|
||||
|
||||
@Override
|
|
@ -15,7 +15,7 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
|
@ -26,6 +26,9 @@ import java.util.Map;
|
|||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
|
@ -180,8 +180,8 @@
|
|||
* of files, recording dimensionally indexed fields, to enable fast numeric range filtering
|
||||
* and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
|
||||
* intersection (2D, 3D).
|
||||
* <li>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}. The
|
||||
* vector format stores numeric vectors in a format optimized for random access and
|
||||
* <li>{@link org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat Vector values}.
|
||||
* The vector format stores numeric vectors in a format optimized for random access and
|
||||
* computation, supporting high-dimensional nearest-neighbor search.
|
||||
* </ul>
|
||||
*
|
||||
|
@ -310,7 +310,7 @@
|
|||
* <td>Holds indexed points</td>
|
||||
* </tr>
|
||||
* <tr>
|
||||
* <td>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
|
||||
* <td>{@link org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
|
||||
* <td>.vec, .vem</td>
|
||||
* <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data, and
|
||||
* <code>.vem</code> the vector metadata</td>
|
||||
|
|
|
@ -17,3 +17,4 @@ org.apache.lucene.backward_codecs.lucene90.Lucene90HnswVectorsFormat
|
|||
org.apache.lucene.backward_codecs.lucene91.Lucene91HnswVectorsFormat
|
||||
org.apache.lucene.backward_codecs.lucene92.Lucene92HnswVectorsFormat
|
||||
org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat
|
||||
org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat
|
||||
|
|
|
@ -15,9 +15,9 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||
import static org.apache.lucene.backward_codecs.lucene95.Lucene95HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
|
@ -29,14 +29,33 @@ import java.util.List;
|
|||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.*;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.*;
|
||||
import org.apache.lucene.util.hnsw.*;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
/**
|
|
@ -0,0 +1,55 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import org.apache.lucene.backward_codecs.lucene90.Lucene90RWSegmentInfoFormat;
|
||||
import org.apache.lucene.backward_codecs.lucene94.Lucene94HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.SegmentInfoFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
|
||||
/** Implements the Lucene 9.5 index format for backwards compat testing */
|
||||
public class Lucene95RWCodec extends Lucene95Codec {
|
||||
|
||||
private final KnnVectorsFormat defaultKnnVectorsFormat;
|
||||
private final KnnVectorsFormat knnVectorsFormat =
|
||||
new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return defaultKnnVectorsFormat;
|
||||
}
|
||||
};
|
||||
private final SegmentInfoFormat segmentInfosFormat = new Lucene90RWSegmentInfoFormat();
|
||||
|
||||
/** Instantiates a new codec. */
|
||||
public Lucene95RWCodec() {
|
||||
defaultKnnVectorsFormat =
|
||||
new Lucene95RWHnswVectorsFormat(
|
||||
Lucene94HnswVectorsFormat.DEFAULT_MAX_CONN,
|
||||
Lucene94HnswVectorsFormat.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public final KnnVectorsFormat knnVectorsFormat() {
|
||||
return knnVectorsFormat;
|
||||
}
|
||||
|
||||
@Override
|
||||
public SegmentInfoFormat segmentInfoFormat() {
|
||||
return segmentInfosFormat;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
|
||||
public final class Lucene95RWHnswVectorsFormat extends Lucene95HnswVectorsFormat {
|
||||
|
||||
public Lucene95RWHnswVectorsFormat(int maxConn, int beamWidth) {
|
||||
super(maxConn, beamWidth);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
return new Lucene95HnswVectorsWriter(state, DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||
return new Lucene95HnswVectorsReader(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Lucene95RWHnswVectorsFormat(name=Lucene95RWHnswVectorsFormat, maxConn="
|
||||
+ maxConn
|
||||
+ ", beamWidth="
|
||||
+ beamWidth
|
||||
+ ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.backward_codecs.lucene95;
|
||||
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
|
||||
public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||
@Override
|
||||
protected Codec getCodec() {
|
||||
return new Lucene95RWCodec();
|
||||
}
|
||||
|
||||
public void testToString() {
|
||||
Lucene95RWCodec customCodec =
|
||||
new Lucene95RWCodec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95RWHnswVectorsFormat(10, 20);
|
||||
}
|
||||
};
|
||||
String expectedString =
|
||||
"Lucene95RWHnswVectorsFormat(name=Lucene95RWHnswVectorsFormat, maxConn=10, beamWidth=20)";
|
||||
assertEquals(expectedString, customCodec.getKnnVectorsFormatForField("bogus_field").toString());
|
||||
}
|
||||
}
|
|
@ -15,8 +15,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
|
||||
/** Lucene Core. */
|
||||
@SuppressWarnings("module") // the test framework is compiled after the core...
|
||||
|
@ -70,7 +70,7 @@ module org.apache.lucene.core {
|
|||
provides org.apache.lucene.codecs.DocValuesFormat with
|
||||
org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||
provides org.apache.lucene.codecs.KnnVectorsFormat with
|
||||
Lucene95HnswVectorsFormat;
|
||||
Lucene99HnswVectorsFormat;
|
||||
provides org.apache.lucene.codecs.PostingsFormat with
|
||||
org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
|
||||
provides org.apache.lucene.index.SortFieldProvider with
|
||||
|
|
|
@ -139,7 +139,7 @@ public abstract class KnnVectorsWriter implements Accountable, Closeable {
|
|||
}
|
||||
|
||||
/** View over multiple vector values supporting iterator-style access via DocIdMerger. */
|
||||
protected static final class MergedVectorValues {
|
||||
public static final class MergedVectorValues {
|
||||
private MergedVectorValues() {}
|
||||
|
||||
/** Returns a merged view over all the segment's {@link FloatVectorValues}. */
|
||||
|
|
|
@ -81,12 +81,12 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
long vectorDataLength,
|
||||
IndexInput vectorData)
|
||||
throws IOException {
|
||||
if (configuration.docsWithFieldOffset == -2 || vectorEncoding != VectorEncoding.BYTE) {
|
||||
if (configuration.isEmpty() || vectorEncoding != VectorEncoding.BYTE) {
|
||||
return new EmptyOffHeapVectorValues(dimension);
|
||||
}
|
||||
IndexInput bytesSlice = vectorData.slice("vector-data", vectorDataOffset, vectorDataLength);
|
||||
int byteSize = dimension;
|
||||
if (configuration.docsWithFieldOffset == -1) {
|
||||
if (configuration.isDense()) {
|
||||
return new DenseOffHeapVectorValues(dimension, configuration.size, bytesSlice, byteSize);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(
|
||||
|
@ -94,9 +94,13 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
}
|
||||
|
||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
public abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
|
||||
static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
/**
|
||||
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
||||
* vector.
|
||||
*/
|
||||
public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
|
@ -134,7 +138,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
}
|
||||
|
@ -203,7 +207,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
|
@ -275,7 +279,7 @@ public abstract class OffHeapByteVectorValues extends ByteVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -88,9 +88,13 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
}
|
||||
|
||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
public abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
|
||||
static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
/**
|
||||
* Dense vector values that are stored off-heap. This is the most common case when every doc has a
|
||||
* vector.
|
||||
*/
|
||||
public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
|
@ -128,7 +132,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
}
|
||||
|
@ -197,7 +201,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
|
@ -269,7 +273,7 @@ public abstract class OffHeapFloatVectorValues extends FloatVectorValues
|
|||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
public Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import org.apache.lucene.index.DocsWithFieldSet;
|
|||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
@ -191,4 +192,48 @@ public class OrdToDocDISIReaderConfiguration implements Accountable {
|
|||
public long ramBytesUsed() {
|
||||
return SHALLOW_SIZE + RamUsageEstimator.sizeOf(meta);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dataIn the dataIn
|
||||
* @return the IndexedDISI for sparse values
|
||||
* @throws IOException thrown when reading data fails
|
||||
*/
|
||||
public IndexedDISI getIndexedDISI(IndexInput dataIn) throws IOException {
|
||||
assert docsWithFieldOffset > -1;
|
||||
return new IndexedDISI(
|
||||
dataIn,
|
||||
docsWithFieldOffset,
|
||||
docsWithFieldLength,
|
||||
jumpTableEntryCount,
|
||||
denseRankPower,
|
||||
size);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param dataIn the dataIn
|
||||
* @return the DirectMonotonicReader for sparse values
|
||||
* @throws IOException thrown when reading data fails
|
||||
*/
|
||||
public DirectMonotonicReader getDirectMonotonicReader(IndexInput dataIn) throws IOException {
|
||||
assert docsWithFieldOffset > -1;
|
||||
final RandomAccessInput addressesData =
|
||||
dataIn.randomAccessSlice(addressesOffset, addressesLength);
|
||||
return DirectMonotonicReader.getInstance(meta, addressesData);
|
||||
}
|
||||
|
||||
/**
|
||||
* @return If true, the field is empty, no vector values. If false, the field is either dense or
|
||||
* sparse.
|
||||
*/
|
||||
public boolean isEmpty() {
|
||||
return docsWithFieldOffset == -2;
|
||||
}
|
||||
|
||||
/**
|
||||
* @return If true, the field is dense, all documents have values for a field. If false, the field
|
||||
* is sparse, some documents missing values.
|
||||
*/
|
||||
public boolean isDense() {
|
||||
return docsWithFieldOffset == -1;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ import java.util.Objects;
|
|||
import org.apache.lucene.codecs.*;
|
||||
import org.apache.lucene.codecs.lucene90.*;
|
||||
import org.apache.lucene.codecs.lucene94.Lucene94FieldInfosFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
|
||||
|
@ -101,7 +100,7 @@ public class Lucene99Codec extends Codec {
|
|||
new Lucene90StoredFieldsFormat(Objects.requireNonNull(mode).storedMode);
|
||||
this.defaultPostingsFormat = new Lucene90PostingsFormat();
|
||||
this.defaultDVFormat = new Lucene90DocValuesFormat();
|
||||
this.defaultKnnVectorsFormat = new Lucene95HnswVectorsFormat();
|
||||
this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat();
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,231 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
|
||||
/**
|
||||
* Lucene 9.9 vector format, which encodes numeric vector values and an optional associated graph
|
||||
* connecting the documents having values. The graph is used to power HNSW search. The format
|
||||
* consists of three files, with an optional fourth file:
|
||||
*
|
||||
* <h2>.vec (vector data) file</h2>
|
||||
*
|
||||
* <p>For each field:
|
||||
*
|
||||
* <ul>
|
||||
* <li>Vector data ordered by field, document ordinal, and vector dimension. When the
|
||||
* vectorEncoding is BYTE, each sample is stored as a single byte. When it is FLOAT32, each
|
||||
* sample is stored as an IEEE float in little-endian byte order.
|
||||
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
||||
* note that only in sparse case
|
||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||
* that only in sparse case
|
||||
* </ul>
|
||||
*
|
||||
* <h2>.vex (vector index)</h2>
|
||||
*
|
||||
* <p>Stores graphs connecting the documents for each field organized as a list of nodes' neighbours
|
||||
* as following:
|
||||
*
|
||||
* <ul>
|
||||
* <li>For each level:
|
||||
* <ul>
|
||||
* <li>For each node:
|
||||
* <ul>
|
||||
* <li><b>[vint]</b> the number of neighbor nodes
|
||||
* <li><b>array[vint]</b> the delta encoded neighbor ordinals
|
||||
* </ul>
|
||||
* </ul>
|
||||
* <li>After all levels are encoded memory offsets for each node's neighbor nodes encoded by
|
||||
* {@link org.apache.lucene.util.packed.DirectMonotonicWriter} are appened to the end of the
|
||||
* file.
|
||||
* </ul>
|
||||
*
|
||||
* <h2>.vem (vector metadata) file</h2>
|
||||
*
|
||||
* <p>For each field:
|
||||
*
|
||||
* <ul>
|
||||
* <li><b>[int32]</b> field number
|
||||
* <li><b>[int32]</b> vector similarity function ordinal
|
||||
* <li><b>[byte]</b> if equals to 1 indicates if the field is for quantized vectors
|
||||
* <li><b>[int32]</b> if quantized: the configured quantile float int bits.
|
||||
* <li><b>[int32]</b> if quantized: the calculated lower quantile float int32 bits.
|
||||
* <li><b>[int32]</b> if quantized: the calculated upper quantile float int32 bits.
|
||||
* <li><b>[vlong]</b> if quantized: offset to this field's vectors in the .veq file
|
||||
* <li><b>[vlong]</b> if quantized: length of this field's vectors, in bytes in the .veq file
|
||||
* <li><b>[vlong]</b> offset to this field's vectors in the .vec file
|
||||
* <li><b>[vlong]</b> length of this field's vectors, in bytes
|
||||
* <li><b>[vlong]</b> offset to this field's index in the .vex file
|
||||
* <li><b>[vlong]</b> length of this field's index data, in bytes
|
||||
* <li><b>[vint]</b> dimension of this field's vectors
|
||||
* <li><b>[int]</b> the number of documents having values for this field
|
||||
* <li><b>[int8]</b> if equals to -1, dense – all documents have values for a field. If equals to
|
||||
* 0, sparse – some documents missing values.
|
||||
* <li>DocIds were encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)}
|
||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||
* that only in sparse case
|
||||
* <li><b>[vint]</b> the maximum number of connections (neigbours) that each node can have
|
||||
* <li><b>[vint]</b> number of levels in the graph
|
||||
* <li>Graph nodes by level. For each level
|
||||
* <ul>
|
||||
* <li><b>[vint]</b> the number of nodes on this level
|
||||
* <li><b>array[vint]</b> for levels greater than 0 list of nodes on this level, stored as
|
||||
* the level 0th delta encoded nodes' ordinals.
|
||||
* </ul>
|
||||
* </ul>
|
||||
*
|
||||
* <h2>.veq (quantized vector data) file</h2>
|
||||
*
|
||||
* <p>For each field:
|
||||
*
|
||||
* <ul>
|
||||
* <li>Vector data ordered by field, document ordinal, and vector dimension. Each vector dimension
|
||||
* is stored as a single byte and every vector has a single float32 value for scoring
|
||||
* corrections.
|
||||
* <li>DocIds encoded by {@link IndexedDISI#writeBitSet(DocIdSetIterator, IndexOutput, byte)},
|
||||
* note that only in sparse case
|
||||
* <li>OrdToDoc was encoded by {@link org.apache.lucene.util.packed.DirectMonotonicWriter}, note
|
||||
* that only in sparse case
|
||||
* </ul>
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99HnswVectorsFormat extends KnnVectorsFormat {
|
||||
|
||||
static final String META_CODEC_NAME = "Lucene99HnswVectorsFormatMeta";
|
||||
static final String VECTOR_DATA_CODEC_NAME = "Lucene99HnswVectorsFormatData";
|
||||
static final String VECTOR_INDEX_CODEC_NAME = "Lucene99HnswVectorsFormatIndex";
|
||||
static final String META_EXTENSION = "vem";
|
||||
static final String VECTOR_DATA_EXTENSION = "vec";
|
||||
static final String VECTOR_INDEX_EXTENSION = "vex";
|
||||
|
||||
public static final int VERSION_START = 0;
|
||||
public static final int VERSION_CURRENT = VERSION_START;
|
||||
|
||||
/**
|
||||
* A maximum configurable maximum max conn.
|
||||
*
|
||||
* <p>NOTE: We eagerly populate `float[MAX_CONN*2]` and `int[MAX_CONN*2]`, so exceptionally large
|
||||
* numbers here will use an inordinate amount of heap
|
||||
*/
|
||||
private static final int MAXIMUM_MAX_CONN = 512;
|
||||
|
||||
/** Default number of maximum connections per node */
|
||||
public static final int DEFAULT_MAX_CONN = 16;
|
||||
|
||||
/**
|
||||
* The maximum size of the queue to maintain while searching during graph construction This
|
||||
* maximum value preserves the ratio of the DEFAULT_BEAM_WIDTH/DEFAULT_MAX_CONN i.e. `6.25 * 16 =
|
||||
* 3200`
|
||||
*/
|
||||
private static final int MAXIMUM_BEAM_WIDTH = 3200;
|
||||
|
||||
/**
|
||||
* Default number of the size of the queue maintained while searching during a graph construction.
|
||||
*/
|
||||
public static final int DEFAULT_BEAM_WIDTH = 100;
|
||||
|
||||
static final int DIRECT_MONOTONIC_BLOCK_SHIFT = 16;
|
||||
|
||||
/**
|
||||
* Controls how many of the nearest neighbor candidates are connected to the new node. Defaults to
|
||||
* {@link Lucene99HnswVectorsFormat#DEFAULT_MAX_CONN}. See {@link HnswGraph} for more details.
|
||||
*/
|
||||
private final int maxConn;
|
||||
|
||||
/**
|
||||
* The number of candidate neighbors to track while searching the graph for each newly inserted
|
||||
* node. Defaults to to {@link Lucene99HnswVectorsFormat#DEFAULT_BEAM_WIDTH}. See {@link
|
||||
* HnswGraph} for details.
|
||||
*/
|
||||
private final int beamWidth;
|
||||
|
||||
/** Should this codec scalar quantize float32 vectors and use this format */
|
||||
private final Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat;
|
||||
|
||||
/** Constructs a format using default graph construction parameters */
|
||||
public Lucene99HnswVectorsFormat() {
|
||||
this(DEFAULT_MAX_CONN, DEFAULT_BEAM_WIDTH, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a format using the given graph construction parameters.
|
||||
*
|
||||
* @param maxConn the maximum number of connections to a node in the HNSW graph
|
||||
* @param beamWidth the size of the queue maintained during graph construction.
|
||||
* @param scalarQuantize the scalar quantization format
|
||||
*/
|
||||
public Lucene99HnswVectorsFormat(
|
||||
int maxConn, int beamWidth, Lucene99ScalarQuantizedVectorsFormat scalarQuantize) {
|
||||
super("Lucene99HnswVectorsFormat");
|
||||
if (maxConn <= 0 || maxConn > MAXIMUM_MAX_CONN) {
|
||||
throw new IllegalArgumentException(
|
||||
"maxConn must be positive and less than or equal to"
|
||||
+ MAXIMUM_MAX_CONN
|
||||
+ "; maxConn="
|
||||
+ maxConn);
|
||||
}
|
||||
if (beamWidth <= 0 || beamWidth > MAXIMUM_BEAM_WIDTH) {
|
||||
throw new IllegalArgumentException(
|
||||
"beamWidth must be positive and less than or equal to"
|
||||
+ MAXIMUM_BEAM_WIDTH
|
||||
+ "; beamWidth="
|
||||
+ beamWidth);
|
||||
}
|
||||
this.maxConn = maxConn;
|
||||
this.beamWidth = beamWidth;
|
||||
this.scalarQuantizedVectorsFormat = scalarQuantize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException {
|
||||
return new Lucene99HnswVectorsWriter(state, maxConn, beamWidth, scalarQuantizedVectorsFormat);
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException {
|
||||
return new Lucene99HnswVectorsReader(state);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int getMaxDimensions(String fieldName) {
|
||||
return 1024;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn="
|
||||
+ maxConn
|
||||
+ ", beamWidth="
|
||||
+ beamWidth
|
||||
+ ", quantizer="
|
||||
+ (scalarQuantizedVectorsFormat == null ? "none" : scalarQuantizedVectorsFormat.toString())
|
||||
+ ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,634 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.HnswGraphProvider;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.CorruptIndexException;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FieldInfos;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.KnnCollector;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.DataInput;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.RandomAccessInput;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
|
||||
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/**
|
||||
* Reads vectors from the index segments along with index data structures supporting KNN search.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99HnswVectorsReader extends KnnVectorsReader
|
||||
implements QuantizedVectorsReader, HnswGraphProvider {
|
||||
|
||||
private static final long SHALLOW_SIZE =
|
||||
RamUsageEstimator.shallowSizeOfInstance(Lucene99HnswVectorsFormat.class);
|
||||
|
||||
private final FieldInfos fieldInfos;
|
||||
private final Map<String, FieldEntry> fields = new HashMap<>();
|
||||
private final IndexInput vectorData;
|
||||
private final IndexInput vectorIndex;
|
||||
private final IndexInput quantizedVectorData;
|
||||
private final Lucene99ScalarQuantizedVectorsReader quantizedVectorsReader;
|
||||
|
||||
Lucene99HnswVectorsReader(SegmentReadState state) throws IOException {
|
||||
this.fieldInfos = state.fieldInfos;
|
||||
int versionMeta = readMetadata(state);
|
||||
boolean success = false;
|
||||
try {
|
||||
vectorData =
|
||||
openDataInput(
|
||||
state,
|
||||
versionMeta,
|
||||
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION,
|
||||
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME);
|
||||
vectorIndex =
|
||||
openDataInput(
|
||||
state,
|
||||
versionMeta,
|
||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION,
|
||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME);
|
||||
if (fields.values().stream().anyMatch(FieldEntry::hasQuantizedVectors)) {
|
||||
quantizedVectorData =
|
||||
openDataInput(
|
||||
state,
|
||||
versionMeta,
|
||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION,
|
||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME);
|
||||
quantizedVectorsReader = new Lucene99ScalarQuantizedVectorsReader(quantizedVectorData);
|
||||
} else {
|
||||
quantizedVectorData = null;
|
||||
quantizedVectorsReader = null;
|
||||
}
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private int readMetadata(SegmentReadState state) throws IOException {
|
||||
String metaFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
||||
int versionMeta = -1;
|
||||
try (ChecksumIndexInput meta = state.directory.openChecksumInput(metaFileName)) {
|
||||
Throwable priorE = null;
|
||||
try {
|
||||
versionMeta =
|
||||
CodecUtil.checkIndexHeader(
|
||||
meta,
|
||||
Lucene99HnswVectorsFormat.META_CODEC_NAME,
|
||||
Lucene99HnswVectorsFormat.VERSION_START,
|
||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
readFields(meta, state.fieldInfos);
|
||||
} catch (Throwable exception) {
|
||||
priorE = exception;
|
||||
} finally {
|
||||
CodecUtil.checkFooter(meta, priorE);
|
||||
}
|
||||
}
|
||||
return versionMeta;
|
||||
}
|
||||
|
||||
private static IndexInput openDataInput(
|
||||
SegmentReadState state, int versionMeta, String fileExtension, String codecName)
|
||||
throws IOException {
|
||||
String fileName =
|
||||
IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension);
|
||||
IndexInput in = state.directory.openInput(fileName, state.context);
|
||||
boolean success = false;
|
||||
try {
|
||||
int versionVectorData =
|
||||
CodecUtil.checkIndexHeader(
|
||||
in,
|
||||
codecName,
|
||||
Lucene99HnswVectorsFormat.VERSION_START,
|
||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
if (versionMeta != versionVectorData) {
|
||||
throw new CorruptIndexException(
|
||||
"Format versions mismatch: meta="
|
||||
+ versionMeta
|
||||
+ ", "
|
||||
+ codecName
|
||||
+ "="
|
||||
+ versionVectorData,
|
||||
in);
|
||||
}
|
||||
CodecUtil.retrieveChecksum(in);
|
||||
success = true;
|
||||
return in;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(in);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOException {
|
||||
for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) {
|
||||
FieldInfo info = infos.fieldInfo(fieldNumber);
|
||||
if (info == null) {
|
||||
throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta);
|
||||
}
|
||||
FieldEntry fieldEntry = readField(meta);
|
||||
validateFieldEntry(info, fieldEntry);
|
||||
fields.put(info.name, fieldEntry);
|
||||
}
|
||||
}
|
||||
|
||||
private void validateFieldEntry(FieldInfo info, FieldEntry fieldEntry) {
|
||||
int dimension = info.getVectorDimension();
|
||||
if (dimension != fieldEntry.dimension) {
|
||||
throw new IllegalStateException(
|
||||
"Inconsistent vector dimension for field=\""
|
||||
+ info.name
|
||||
+ "\"; "
|
||||
+ dimension
|
||||
+ " != "
|
||||
+ fieldEntry.dimension);
|
||||
}
|
||||
|
||||
int byteSize =
|
||||
switch (info.getVectorEncoding()) {
|
||||
case BYTE -> Byte.BYTES;
|
||||
case FLOAT32 -> Float.BYTES;
|
||||
};
|
||||
long vectorBytes = Math.multiplyExact((long) dimension, byteSize);
|
||||
long numBytes = Math.multiplyExact(vectorBytes, fieldEntry.size);
|
||||
if (numBytes != fieldEntry.vectorDataLength) {
|
||||
throw new IllegalStateException(
|
||||
"Vector data length "
|
||||
+ fieldEntry.vectorDataLength
|
||||
+ " not matching size="
|
||||
+ fieldEntry.size
|
||||
+ " * dim="
|
||||
+ dimension
|
||||
+ " * byteSize="
|
||||
+ byteSize
|
||||
+ " = "
|
||||
+ numBytes);
|
||||
}
|
||||
if (fieldEntry.hasQuantizedVectors()) {
|
||||
Lucene99ScalarQuantizedVectorsReader.validateFieldEntry(
|
||||
info, fieldEntry.dimension, fieldEntry.size, fieldEntry.quantizedVectorDataLength);
|
||||
}
|
||||
}
|
||||
|
||||
private VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException {
|
||||
int similarityFunctionId = input.readInt();
|
||||
if (similarityFunctionId < 0
|
||||
|| similarityFunctionId >= VectorSimilarityFunction.values().length) {
|
||||
throw new CorruptIndexException(
|
||||
"Invalid similarity function id: " + similarityFunctionId, input);
|
||||
}
|
||||
return VectorSimilarityFunction.values()[similarityFunctionId];
|
||||
}
|
||||
|
||||
private VectorEncoding readVectorEncoding(DataInput input) throws IOException {
|
||||
int encodingId = input.readInt();
|
||||
if (encodingId < 0 || encodingId >= VectorEncoding.values().length) {
|
||||
throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input);
|
||||
}
|
||||
return VectorEncoding.values()[encodingId];
|
||||
}
|
||||
|
||||
private FieldEntry readField(IndexInput input) throws IOException {
|
||||
VectorEncoding vectorEncoding = readVectorEncoding(input);
|
||||
VectorSimilarityFunction similarityFunction = readSimilarityFunction(input);
|
||||
return new FieldEntry(input, vectorEncoding, similarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return Lucene99HnswVectorsReader.SHALLOW_SIZE
|
||||
+ RamUsageEstimator.sizeOfMap(
|
||||
fields, RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkIntegrity() throws IOException {
|
||||
CodecUtil.checksumEntireFile(vectorData);
|
||||
CodecUtil.checksumEntireFile(vectorIndex);
|
||||
if (quantizedVectorsReader != null) {
|
||||
quantizedVectorsReader.checkIntegrity();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapFloatVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ByteVectorValues getByteVectorValues(String field) throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
throw new IllegalArgumentException(
|
||||
"field=\""
|
||||
+ field
|
||||
+ "\" is encoded as: "
|
||||
+ fieldEntry.vectorEncoding
|
||||
+ " expected: "
|
||||
+ VectorEncoding.FLOAT32);
|
||||
}
|
||||
return OffHeapByteVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
vectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
|
||||
return;
|
||||
}
|
||||
if (fieldEntry.hasQuantizedVectors()) {
|
||||
OffHeapQuantizedByteVectorValues vectorValues =
|
||||
quantizedVectorsReader.getQuantizedVectorValues(
|
||||
fieldEntry.quantizedOrdToDoc,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.size,
|
||||
fieldEntry.quantizedVectorDataOffset,
|
||||
fieldEntry.quantizedVectorDataLength);
|
||||
if (vectorValues == null) {
|
||||
return;
|
||||
}
|
||||
RandomVectorScorer scorer =
|
||||
new ScalarQuantizedRandomVectorScorer(
|
||||
fieldEntry.similarityFunction, fieldEntry.scalarQuantizer, vectorValues, target);
|
||||
HnswGraphSearcher.search(
|
||||
scorer,
|
||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
||||
getGraph(fieldEntry),
|
||||
vectorValues.getAcceptOrds(acceptDocs));
|
||||
} else {
|
||||
OffHeapFloatVectorValues vectorValues =
|
||||
OffHeapFloatVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
vectorData);
|
||||
RandomVectorScorer scorer =
|
||||
RandomVectorScorer.createFloats(vectorValues, fieldEntry.similarityFunction, target);
|
||||
HnswGraphSearcher.search(
|
||||
scorer,
|
||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
||||
getGraph(fieldEntry),
|
||||
vectorValues.getAcceptOrds(acceptDocs));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
|
||||
if (fieldEntry.size() == 0
|
||||
|| knnCollector.k() == 0
|
||||
|| fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
|
||||
return;
|
||||
}
|
||||
|
||||
OffHeapByteVectorValues vectorValues =
|
||||
OffHeapByteVectorValues.load(
|
||||
fieldEntry.ordToDoc,
|
||||
fieldEntry.vectorEncoding,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.vectorDataOffset,
|
||||
fieldEntry.vectorDataLength,
|
||||
vectorData);
|
||||
RandomVectorScorer scorer =
|
||||
RandomVectorScorer.createBytes(vectorValues, fieldEntry.similarityFunction, target);
|
||||
HnswGraphSearcher.search(
|
||||
scorer,
|
||||
new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc),
|
||||
getGraph(fieldEntry),
|
||||
vectorValues.getAcceptOrds(acceptDocs));
|
||||
}
|
||||
|
||||
@Override
|
||||
public HnswGraph getGraph(String field) throws IOException {
|
||||
FieldInfo info = fieldInfos.fieldInfo(field);
|
||||
if (info == null) {
|
||||
throw new IllegalArgumentException("No such field '" + field + "'");
|
||||
}
|
||||
FieldEntry entry = fields.get(field);
|
||||
if (entry != null && entry.vectorIndexLength > 0) {
|
||||
return getGraph(entry);
|
||||
} else {
|
||||
return HnswGraph.EMPTY;
|
||||
}
|
||||
}
|
||||
|
||||
private HnswGraph getGraph(FieldEntry entry) throws IOException {
|
||||
return new OffHeapHnswGraph(entry, vectorIndex);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(vectorData, vectorIndex, quantizedVectorData);
|
||||
}
|
||||
|
||||
@Override
|
||||
public OffHeapQuantizedByteVectorValues getQuantizedVectorValues(String field)
|
||||
throws IOException {
|
||||
FieldEntry fieldEntry = fields.get(field);
|
||||
if (fieldEntry == null || fieldEntry.hasQuantizedVectors() == false) {
|
||||
return null;
|
||||
}
|
||||
assert quantizedVectorsReader != null && fieldEntry.quantizedOrdToDoc != null;
|
||||
return quantizedVectorsReader.getQuantizedVectorValues(
|
||||
fieldEntry.quantizedOrdToDoc,
|
||||
fieldEntry.dimension,
|
||||
fieldEntry.size,
|
||||
fieldEntry.quantizedVectorDataOffset,
|
||||
fieldEntry.quantizedVectorDataLength);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScalarQuantizer getQuantizationState(String fieldName) {
|
||||
FieldEntry field = fields.get(fieldName);
|
||||
if (field == null || field.hasQuantizedVectors() == false) {
|
||||
return null;
|
||||
}
|
||||
return field.scalarQuantizer;
|
||||
}
|
||||
|
||||
static class FieldEntry implements Accountable {
|
||||
private static final long SHALLOW_SIZE =
|
||||
RamUsageEstimator.shallowSizeOfInstance(FieldEntry.class);
|
||||
final VectorSimilarityFunction similarityFunction;
|
||||
final VectorEncoding vectorEncoding;
|
||||
final long vectorDataOffset;
|
||||
final long vectorDataLength;
|
||||
final long vectorIndexOffset;
|
||||
final long vectorIndexLength;
|
||||
final int M;
|
||||
final int numLevels;
|
||||
final int dimension;
|
||||
final int size;
|
||||
final int[][] nodesByLevel;
|
||||
// for each level the start offsets in vectorIndex file from where to read neighbours
|
||||
final DirectMonotonicReader.Meta offsetsMeta;
|
||||
final long offsetsOffset;
|
||||
final int offsetsBlockShift;
|
||||
final long offsetsLength;
|
||||
final OrdToDocDISIReaderConfiguration ordToDoc;
|
||||
|
||||
final float configuredQuantile, lowerQuantile, upperQuantile;
|
||||
final long quantizedVectorDataOffset, quantizedVectorDataLength;
|
||||
final ScalarQuantizer scalarQuantizer;
|
||||
final boolean isQuantized;
|
||||
final OrdToDocDISIReaderConfiguration quantizedOrdToDoc;
|
||||
|
||||
FieldEntry(
|
||||
IndexInput input,
|
||||
VectorEncoding vectorEncoding,
|
||||
VectorSimilarityFunction similarityFunction)
|
||||
throws IOException {
|
||||
this.similarityFunction = similarityFunction;
|
||||
this.vectorEncoding = vectorEncoding;
|
||||
this.isQuantized = input.readByte() == 1;
|
||||
// Has int8 quantization
|
||||
if (isQuantized) {
|
||||
configuredQuantile = Float.intBitsToFloat(input.readInt());
|
||||
lowerQuantile = Float.intBitsToFloat(input.readInt());
|
||||
upperQuantile = Float.intBitsToFloat(input.readInt());
|
||||
quantizedVectorDataOffset = input.readVLong();
|
||||
quantizedVectorDataLength = input.readVLong();
|
||||
scalarQuantizer = new ScalarQuantizer(lowerQuantile, upperQuantile, configuredQuantile);
|
||||
} else {
|
||||
configuredQuantile = -1;
|
||||
lowerQuantile = -1;
|
||||
upperQuantile = -1;
|
||||
quantizedVectorDataOffset = -1;
|
||||
quantizedVectorDataLength = -1;
|
||||
scalarQuantizer = null;
|
||||
}
|
||||
vectorDataOffset = input.readVLong();
|
||||
vectorDataLength = input.readVLong();
|
||||
vectorIndexOffset = input.readVLong();
|
||||
vectorIndexLength = input.readVLong();
|
||||
dimension = input.readVInt();
|
||||
size = input.readInt();
|
||||
if (isQuantized) {
|
||||
quantizedOrdToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
||||
} else {
|
||||
quantizedOrdToDoc = null;
|
||||
}
|
||||
ordToDoc = OrdToDocDISIReaderConfiguration.fromStoredMeta(input, size);
|
||||
|
||||
// read nodes by level
|
||||
M = input.readVInt();
|
||||
numLevels = input.readVInt();
|
||||
nodesByLevel = new int[numLevels][];
|
||||
long numberOfOffsets = 0;
|
||||
for (int level = 0; level < numLevels; level++) {
|
||||
if (level > 0) {
|
||||
int numNodesOnLevel = input.readVInt();
|
||||
numberOfOffsets += numNodesOnLevel;
|
||||
nodesByLevel[level] = new int[numNodesOnLevel];
|
||||
nodesByLevel[level][0] = input.readVInt();
|
||||
for (int i = 1; i < numNodesOnLevel; i++) {
|
||||
nodesByLevel[level][i] = nodesByLevel[level][i - 1] + input.readVInt();
|
||||
}
|
||||
} else {
|
||||
numberOfOffsets += size;
|
||||
}
|
||||
}
|
||||
if (numberOfOffsets > 0) {
|
||||
offsetsOffset = input.readLong();
|
||||
offsetsBlockShift = input.readVInt();
|
||||
offsetsMeta = DirectMonotonicReader.loadMeta(input, numberOfOffsets, offsetsBlockShift);
|
||||
offsetsLength = input.readLong();
|
||||
} else {
|
||||
offsetsOffset = 0;
|
||||
offsetsBlockShift = 0;
|
||||
offsetsMeta = null;
|
||||
offsetsLength = 0;
|
||||
}
|
||||
}
|
||||
|
||||
int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
boolean hasQuantizedVectors() {
|
||||
return isQuantized;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return SHALLOW_SIZE
|
||||
+ Arrays.stream(nodesByLevel).mapToLong(nodes -> RamUsageEstimator.sizeOf(nodes)).sum()
|
||||
+ RamUsageEstimator.sizeOf(ordToDoc)
|
||||
+ (quantizedOrdToDoc == null ? 0 : RamUsageEstimator.sizeOf(quantizedOrdToDoc))
|
||||
+ RamUsageEstimator.sizeOf(offsetsMeta);
|
||||
}
|
||||
}
|
||||
|
||||
/** Read the nearest-neighbors graph from the index input */
|
||||
private static final class OffHeapHnswGraph extends HnswGraph {
|
||||
|
||||
final IndexInput dataIn;
|
||||
final int[][] nodesByLevel;
|
||||
final int numLevels;
|
||||
final int entryNode;
|
||||
final int size;
|
||||
int arcCount;
|
||||
int arcUpTo;
|
||||
int arc;
|
||||
private final DirectMonotonicReader graphLevelNodeOffsets;
|
||||
private final long[] graphLevelNodeIndexOffsets;
|
||||
// Allocated to be M*2 to track the current neighbors being explored
|
||||
private final int[] currentNeighborsBuffer;
|
||||
|
||||
OffHeapHnswGraph(FieldEntry entry, IndexInput vectorIndex) throws IOException {
|
||||
this.dataIn =
|
||||
vectorIndex.slice("graph-data", entry.vectorIndexOffset, entry.vectorIndexLength);
|
||||
this.nodesByLevel = entry.nodesByLevel;
|
||||
this.numLevels = entry.numLevels;
|
||||
this.entryNode = numLevels > 1 ? nodesByLevel[numLevels - 1][0] : 0;
|
||||
this.size = entry.size();
|
||||
final RandomAccessInput addressesData =
|
||||
vectorIndex.randomAccessSlice(entry.offsetsOffset, entry.offsetsLength);
|
||||
this.graphLevelNodeOffsets =
|
||||
DirectMonotonicReader.getInstance(entry.offsetsMeta, addressesData);
|
||||
this.currentNeighborsBuffer = new int[entry.M * 2];
|
||||
graphLevelNodeIndexOffsets = new long[numLevels];
|
||||
graphLevelNodeIndexOffsets[0] = 0;
|
||||
for (int i = 1; i < numLevels; i++) {
|
||||
// nodesByLevel is `null` for the zeroth level as we know its size
|
||||
int nodeCount = nodesByLevel[i - 1] == null ? size : nodesByLevel[i - 1].length;
|
||||
graphLevelNodeIndexOffsets[i] = graphLevelNodeIndexOffsets[i - 1] + nodeCount;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int targetOrd) throws IOException {
|
||||
int targetIndex =
|
||||
level == 0
|
||||
? targetOrd
|
||||
: Arrays.binarySearch(nodesByLevel[level], 0, nodesByLevel[level].length, targetOrd);
|
||||
assert targetIndex >= 0;
|
||||
// unsafe; no bounds checking
|
||||
dataIn.seek(graphLevelNodeOffsets.get(targetIndex + graphLevelNodeIndexOffsets[level]));
|
||||
arcCount = dataIn.readVInt();
|
||||
if (arcCount > 0) {
|
||||
currentNeighborsBuffer[0] = dataIn.readVInt();
|
||||
for (int i = 1; i < arcCount; i++) {
|
||||
currentNeighborsBuffer[i] = currentNeighborsBuffer[i - 1] + dataIn.readVInt();
|
||||
}
|
||||
}
|
||||
arc = -1;
|
||||
arcUpTo = 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextNeighbor() throws IOException {
|
||||
if (arcUpTo >= arcCount) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
arc = currentNeighborsBuffer[arcUpTo];
|
||||
++arcUpTo;
|
||||
return arc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() throws IOException {
|
||||
return numLevels;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() throws IOException {
|
||||
return entryNode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return new ArrayNodesIterator(size());
|
||||
} else {
|
||||
return new ArrayNodesIterator(nodesByLevel[level], nodesByLevel[level].length);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,985 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexFileNames;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph;
|
||||
import org.apache.lucene.util.hnsw.HnswGraph.NodesIterator;
|
||||
import org.apache.lucene.util.hnsw.HnswGraphBuilder;
|
||||
import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger;
|
||||
import org.apache.lucene.util.hnsw.NeighborArray;
|
||||
import org.apache.lucene.util.hnsw.OnHeapHnswGraph;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicWriter;
|
||||
|
||||
/**
|
||||
* Writes vector values and knn graphs to index segments.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99HnswVectorsWriter extends KnnVectorsWriter {
|
||||
|
||||
private final SegmentWriteState segmentWriteState;
|
||||
private final IndexOutput meta, vectorData, quantizedVectorData, vectorIndex;
|
||||
private final int M;
|
||||
private final int beamWidth;
|
||||
private final Lucene99ScalarQuantizedVectorsWriter quantizedVectorsWriter;
|
||||
|
||||
private final List<FieldWriter<?>> fields = new ArrayList<>();
|
||||
private boolean finished;
|
||||
|
||||
Lucene99HnswVectorsWriter(
|
||||
SegmentWriteState state,
|
||||
int M,
|
||||
int beamWidth,
|
||||
Lucene99ScalarQuantizedVectorsFormat quantizedVectorsFormat)
|
||||
throws IOException {
|
||||
this.M = M;
|
||||
this.beamWidth = beamWidth;
|
||||
segmentWriteState = state;
|
||||
String metaFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name, state.segmentSuffix, Lucene99HnswVectorsFormat.META_EXTENSION);
|
||||
|
||||
String vectorDataFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
Lucene99HnswVectorsFormat.VECTOR_DATA_EXTENSION);
|
||||
|
||||
String indexDataFileName =
|
||||
IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_EXTENSION);
|
||||
|
||||
final String quantizedVectorDataFileName =
|
||||
quantizedVectorsFormat != null
|
||||
? IndexFileNames.segmentFileName(
|
||||
state.segmentInfo.name,
|
||||
state.segmentSuffix,
|
||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_EXTENSION)
|
||||
: null;
|
||||
boolean success = false;
|
||||
try {
|
||||
meta = state.directory.createOutput(metaFileName, state.context);
|
||||
vectorData = state.directory.createOutput(vectorDataFileName, state.context);
|
||||
vectorIndex = state.directory.createOutput(indexDataFileName, state.context);
|
||||
|
||||
CodecUtil.writeIndexHeader(
|
||||
meta,
|
||||
Lucene99HnswVectorsFormat.META_CODEC_NAME,
|
||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
CodecUtil.writeIndexHeader(
|
||||
vectorData,
|
||||
Lucene99HnswVectorsFormat.VECTOR_DATA_CODEC_NAME,
|
||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
CodecUtil.writeIndexHeader(
|
||||
vectorIndex,
|
||||
Lucene99HnswVectorsFormat.VECTOR_INDEX_CODEC_NAME,
|
||||
Lucene99HnswVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
if (quantizedVectorDataFileName != null) {
|
||||
quantizedVectorData =
|
||||
state.directory.createOutput(quantizedVectorDataFileName, state.context);
|
||||
CodecUtil.writeIndexHeader(
|
||||
quantizedVectorData,
|
||||
Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_DATA_CODEC_NAME,
|
||||
Lucene99ScalarQuantizedVectorsFormat.VERSION_CURRENT,
|
||||
state.segmentInfo.getId(),
|
||||
state.segmentSuffix);
|
||||
quantizedVectorsWriter =
|
||||
new Lucene99ScalarQuantizedVectorsWriter(
|
||||
quantizedVectorData, quantizedVectorsFormat.quantile);
|
||||
} else {
|
||||
quantizedVectorData = null;
|
||||
quantizedVectorsWriter = null;
|
||||
}
|
||||
success = true;
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(this);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOException {
|
||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedVectorFieldWriter =
|
||||
null;
|
||||
// Quantization only supports FLOAT32 for now
|
||||
if (quantizedVectorsWriter != null
|
||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||
quantizedVectorFieldWriter =
|
||||
quantizedVectorsWriter.addField(fieldInfo, segmentWriteState.infoStream);
|
||||
}
|
||||
FieldWriter<?> newField =
|
||||
FieldWriter.create(
|
||||
fieldInfo, M, beamWidth, segmentWriteState.infoStream, quantizedVectorFieldWriter);
|
||||
fields.add(newField);
|
||||
return newField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
|
||||
for (FieldWriter<?> field : fields) {
|
||||
long[] quantizedVectorOffsetAndLen = null;
|
||||
if (field.quantizedWriter != null) {
|
||||
assert quantizedVectorsWriter != null;
|
||||
quantizedVectorOffsetAndLen =
|
||||
quantizedVectorsWriter.flush(sortMap, field.quantizedWriter, field.docsWithField);
|
||||
}
|
||||
if (sortMap == null) {
|
||||
writeField(field, maxDoc, quantizedVectorOffsetAndLen);
|
||||
} else {
|
||||
writeSortingField(field, maxDoc, sortMap, quantizedVectorOffsetAndLen);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finish() throws IOException {
|
||||
if (finished) {
|
||||
throw new IllegalStateException("already finished");
|
||||
}
|
||||
finished = true;
|
||||
if (quantizedVectorsWriter != null) {
|
||||
quantizedVectorsWriter.finish();
|
||||
}
|
||||
|
||||
if (meta != null) {
|
||||
// write end of fields marker
|
||||
meta.writeInt(-1);
|
||||
CodecUtil.writeFooter(meta);
|
||||
}
|
||||
if (vectorData != null) {
|
||||
CodecUtil.writeFooter(vectorData);
|
||||
CodecUtil.writeFooter(vectorIndex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
long total = 0;
|
||||
for (FieldWriter<?> field : fields) {
|
||||
total += field.ramBytesUsed();
|
||||
}
|
||||
return total;
|
||||
}
|
||||
|
||||
private void writeField(FieldWriter<?> fieldData, int maxDoc, long[] quantizedVecOffsetAndLen)
|
||||
throws IOException {
|
||||
// write vector values
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectors(fieldData);
|
||||
case FLOAT32 -> writeFloat32Vectors(fieldData);
|
||||
}
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||
int[][] graphLevelNodeOffsets = writeGraph(graph);
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
||||
writeMeta(
|
||||
fieldData.isQuantized(),
|
||||
fieldData.fieldInfo,
|
||||
maxDoc,
|
||||
fieldData.getConfiguredQuantile(),
|
||||
fieldData.getMinQuantile(),
|
||||
fieldData.getMaxQuantile(),
|
||||
quantizedVecOffsetAndLen,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
vectorIndexLength,
|
||||
fieldData.docsWithField,
|
||||
graph,
|
||||
graphLevelNodeOffsets);
|
||||
}
|
||||
|
||||
private void writeFloat32Vectors(FieldWriter<?> fieldData) throws IOException {
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (Object v : fieldData.vectors) {
|
||||
buffer.asFloatBuffer().put((float[]) v);
|
||||
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeByteVectors(FieldWriter<?> fieldData) throws IOException {
|
||||
for (Object v : fieldData.vectors) {
|
||||
byte[] vector = (byte[]) v;
|
||||
vectorData.writeBytes(vector, vector.length);
|
||||
}
|
||||
}
|
||||
|
||||
private void writeSortingField(
|
||||
FieldWriter<?> fieldData,
|
||||
int maxDoc,
|
||||
Sorter.DocMap sortMap,
|
||||
long[] quantizedVectorOffsetAndLen)
|
||||
throws IOException {
|
||||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
DocIdSetIterator iterator = fieldData.docsWithField.iterator();
|
||||
for (int docID = iterator.nextDoc();
|
||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docID = iterator.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
||||
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
||||
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
|
||||
int ord = 0;
|
||||
int doc = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord] = docIdOffset - 1;
|
||||
oldOrdMap[docIdOffset - 1] = ord;
|
||||
newDocsWithField.add(doc);
|
||||
ord++;
|
||||
}
|
||||
doc++;
|
||||
}
|
||||
|
||||
// write vector values
|
||||
long vectorDataOffset =
|
||||
switch (fieldData.fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeSortedByteVectors(fieldData, ordMap);
|
||||
case FLOAT32 -> writeSortedFloat32Vectors(fieldData, ordMap);
|
||||
};
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
// write graph
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
OnHeapHnswGraph graph = fieldData.getGraph();
|
||||
int[][] graphLevelNodeOffsets = graph == null ? new int[0][] : new int[graph.numLevels()][];
|
||||
HnswGraph mockGraph = reconstructAndWriteGraph(graph, ordMap, oldOrdMap, graphLevelNodeOffsets);
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
|
||||
writeMeta(
|
||||
fieldData.isQuantized(),
|
||||
fieldData.fieldInfo,
|
||||
maxDoc,
|
||||
fieldData.getConfiguredQuantile(),
|
||||
fieldData.getMinQuantile(),
|
||||
fieldData.getMaxQuantile(),
|
||||
quantizedVectorOffsetAndLen,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
vectorIndexLength,
|
||||
newDocsWithField,
|
||||
mockGraph,
|
||||
graphLevelNodeOffsets);
|
||||
}
|
||||
|
||||
private long writeSortedFloat32Vectors(FieldWriter<?> fieldData, int[] ordMap)
|
||||
throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
final ByteBuffer buffer =
|
||||
ByteBuffer.allocate(fieldData.dim * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int ordinal : ordMap) {
|
||||
float[] vector = (float[]) fieldData.vectors.get(ordinal);
|
||||
buffer.asFloatBuffer().put(vector);
|
||||
vectorData.writeBytes(buffer.array(), buffer.array().length);
|
||||
}
|
||||
return vectorDataOffset;
|
||||
}
|
||||
|
||||
private long writeSortedByteVectors(FieldWriter<?> fieldData, int[] ordMap) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
for (int ordinal : ordMap) {
|
||||
byte[] vector = (byte[]) fieldData.vectors.get(ordinal);
|
||||
vectorData.writeBytes(vector, vector.length);
|
||||
}
|
||||
return vectorDataOffset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Reconstructs the graph given the old and new node ids.
|
||||
*
|
||||
* <p>Additionally, the graph node connections are written to the vectorIndex.
|
||||
*
|
||||
* @param graph The current on heap graph
|
||||
* @param newToOldMap the new node ids indexed to the old node ids
|
||||
* @param oldToNewMap the old node ids indexed to the new node ids
|
||||
* @param levelNodeOffsets where to place the new offsets for the nodes in the vector index.
|
||||
* @return The graph
|
||||
* @throws IOException if writing to vectorIndex fails
|
||||
*/
|
||||
private HnswGraph reconstructAndWriteGraph(
|
||||
OnHeapHnswGraph graph, int[] newToOldMap, int[] oldToNewMap, int[][] levelNodeOffsets)
|
||||
throws IOException {
|
||||
if (graph == null) return null;
|
||||
|
||||
List<int[]> nodesByLevel = new ArrayList<>(graph.numLevels());
|
||||
nodesByLevel.add(null);
|
||||
|
||||
int maxOrd = graph.size();
|
||||
NodesIterator nodesOnLevel0 = graph.getNodesOnLevel(0);
|
||||
levelNodeOffsets[0] = new int[nodesOnLevel0.size()];
|
||||
while (nodesOnLevel0.hasNext()) {
|
||||
int node = nodesOnLevel0.nextInt();
|
||||
NeighborArray neighbors = graph.getNeighbors(0, newToOldMap[node]);
|
||||
long offset = vectorIndex.getFilePointer();
|
||||
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
|
||||
levelNodeOffsets[0][node] = Math.toIntExact(vectorIndex.getFilePointer() - offset);
|
||||
}
|
||||
|
||||
for (int level = 1; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
int[] newNodes = new int[nodesOnLevel.size()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
newNodes[n] = oldToNewMap[nodesOnLevel.nextInt()];
|
||||
}
|
||||
Arrays.sort(newNodes);
|
||||
nodesByLevel.add(newNodes);
|
||||
levelNodeOffsets[level] = new int[newNodes.length];
|
||||
int nodeOffsetIndex = 0;
|
||||
for (int node : newNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, newToOldMap[node]);
|
||||
long offset = vectorIndex.getFilePointer();
|
||||
reconstructAndWriteNeigbours(neighbors, oldToNewMap, maxOrd);
|
||||
levelNodeOffsets[level][nodeOffsetIndex++] =
|
||||
Math.toIntExact(vectorIndex.getFilePointer() - offset);
|
||||
}
|
||||
}
|
||||
return new HnswGraph() {
|
||||
@Override
|
||||
public int nextNeighbor() {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void seek(int level, int target) {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return graph.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int numLevels() {
|
||||
return graph.numLevels();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int entryNode() {
|
||||
throw new UnsupportedOperationException("Not supported on a mock graph");
|
||||
}
|
||||
|
||||
@Override
|
||||
public NodesIterator getNodesOnLevel(int level) {
|
||||
if (level == 0) {
|
||||
return graph.getNodesOnLevel(0);
|
||||
} else {
|
||||
return new ArrayNodesIterator(nodesByLevel.get(level), nodesByLevel.get(level).length);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private void reconstructAndWriteNeigbours(NeighborArray neighbors, int[] oldToNewMap, int maxOrd)
|
||||
throws IOException {
|
||||
int size = neighbors.size();
|
||||
vectorIndex.writeVInt(size);
|
||||
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
int[] nnodes = neighbors.node();
|
||||
for (int i = 0; i < size; i++) {
|
||||
nnodes[i] = oldToNewMap[nnodes[i]];
|
||||
}
|
||||
Arrays.sort(nnodes, 0, size);
|
||||
// Now that we have sorted, do delta encoding to minimize the required bits to store the
|
||||
// information
|
||||
for (int i = size - 1; i > 0; --i) {
|
||||
assert nnodes[i] < maxOrd : "node too large: " + nnodes[i] + ">=" + maxOrd;
|
||||
nnodes[i] -= nnodes[i - 1];
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
vectorIndex.writeVInt(nnodes[i]);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
long vectorDataOffset = vectorData.alignFilePointer(Float.BYTES);
|
||||
IndexOutput tempVectorData = null;
|
||||
IndexInput vectorDataInput = null;
|
||||
CloseableRandomVectorScorerSupplier scorerSupplier = null;
|
||||
boolean success = false;
|
||||
try {
|
||||
ScalarQuantizer scalarQuantizer = null;
|
||||
long[] quantizedVectorDataOffsetAndLength = null;
|
||||
// If we have configured quantization and are FLOAT32
|
||||
if (quantizedVectorsWriter != null
|
||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) {
|
||||
// We need the quantization parameters to write to the meta file
|
||||
scalarQuantizer = quantizedVectorsWriter.mergeQuantiles(fieldInfo, mergeState);
|
||||
if (segmentWriteState.infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
||||
segmentWriteState.infoStream.message(
|
||||
QUANTIZED_VECTOR_COMPONENT,
|
||||
"Merged quantiles field: "
|
||||
+ fieldInfo.name
|
||||
+ " newly merged quantile: "
|
||||
+ scalarQuantizer);
|
||||
}
|
||||
assert scalarQuantizer != null;
|
||||
quantizedVectorDataOffsetAndLength = new long[2];
|
||||
quantizedVectorDataOffsetAndLength[0] = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||
scorerSupplier =
|
||||
quantizedVectorsWriter.mergeOneField(
|
||||
segmentWriteState, fieldInfo, mergeState, scalarQuantizer);
|
||||
quantizedVectorDataOffsetAndLength[1] =
|
||||
quantizedVectorData.getFilePointer() - quantizedVectorDataOffsetAndLength[0];
|
||||
}
|
||||
final DocsWithFieldSet docsWithField;
|
||||
int byteSize = fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize;
|
||||
|
||||
// If we extract vector storage, this could be cleaner.
|
||||
// But for now, vector storage & index creation/storage live together.
|
||||
if (scorerSupplier == null) {
|
||||
tempVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
vectorData.getName(), "temp", segmentWriteState.context);
|
||||
docsWithField =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
tempVectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
CodecUtil.writeFooter(tempVectorData);
|
||||
IOUtils.close(tempVectorData);
|
||||
// copy the temporary file vectors to the actual data file
|
||||
vectorDataInput =
|
||||
segmentWriteState.directory.openInput(
|
||||
tempVectorData.getName(), segmentWriteState.context);
|
||||
vectorData.copyBytes(vectorDataInput, vectorDataInput.length() - CodecUtil.footerLength());
|
||||
CodecUtil.retrieveChecksum(vectorDataInput);
|
||||
final RandomVectorScorerSupplier innerScoreSupplier =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
||||
new OffHeapByteVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize),
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
|
||||
new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(),
|
||||
docsWithField.cardinality(),
|
||||
vectorDataInput,
|
||||
byteSize),
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
};
|
||||
final String tempFileName = tempVectorData.getName();
|
||||
final IndexInput finalVectorDataInput = vectorDataInput;
|
||||
scorerSupplier =
|
||||
new CloseableRandomVectorScorerSupplier() {
|
||||
boolean closed = false;
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||
return innerScoreSupplier.scorer(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
if (closed) {
|
||||
return;
|
||||
}
|
||||
closed = true;
|
||||
IOUtils.close(finalVectorDataInput);
|
||||
segmentWriteState.directory.deleteFile(tempFileName);
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// No need to use temporary file as we don't have to re-open for reading
|
||||
docsWithField =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> writeByteVectorData(
|
||||
vectorData, MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState));
|
||||
case FLOAT32 -> writeVectorData(
|
||||
vectorData, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState));
|
||||
};
|
||||
}
|
||||
|
||||
long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset;
|
||||
long vectorIndexOffset = vectorIndex.getFilePointer();
|
||||
// build the graph using the temporary vector data
|
||||
// we use Lucene99HnswVectorsReader.DenseOffHeapVectorValues for the graph construction
|
||||
// doesn't need to know docIds
|
||||
// TODO: separate random access vector values from DocIdSetIterator?
|
||||
OnHeapHnswGraph graph = null;
|
||||
int[][] vectorIndexNodeOffsets = null;
|
||||
if (docsWithField.cardinality() != 0) {
|
||||
// build graph
|
||||
IncrementalHnswGraphMerger merger =
|
||||
new IncrementalHnswGraphMerger(fieldInfo, scorerSupplier, M, beamWidth);
|
||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||
merger.addReader(
|
||||
mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]);
|
||||
}
|
||||
DocIdSetIterator mergedVectorIterator = null;
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> mergedVectorIterator =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState);
|
||||
case FLOAT32 -> mergedVectorIterator =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
}
|
||||
HnswGraphBuilder hnswGraphBuilder = merger.createBuilder(mergedVectorIterator);
|
||||
hnswGraphBuilder.setInfoStream(segmentWriteState.infoStream);
|
||||
graph = hnswGraphBuilder.build(docsWithField.cardinality());
|
||||
vectorIndexNodeOffsets = writeGraph(graph);
|
||||
}
|
||||
long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset;
|
||||
writeMeta(
|
||||
scalarQuantizer != null,
|
||||
fieldInfo,
|
||||
segmentWriteState.segmentInfo.maxDoc(),
|
||||
scalarQuantizer == null ? null : scalarQuantizer.getConfiguredQuantile(),
|
||||
scalarQuantizer == null ? null : scalarQuantizer.getLowerQuantile(),
|
||||
scalarQuantizer == null ? null : scalarQuantizer.getUpperQuantile(),
|
||||
quantizedVectorDataOffsetAndLength,
|
||||
vectorDataOffset,
|
||||
vectorDataLength,
|
||||
vectorIndexOffset,
|
||||
vectorIndexLength,
|
||||
docsWithField,
|
||||
graph,
|
||||
vectorIndexNodeOffsets);
|
||||
success = true;
|
||||
} finally {
|
||||
if (success) {
|
||||
IOUtils.close(scorerSupplier);
|
||||
} else {
|
||||
IOUtils.closeWhileHandlingException(scorerSupplier, vectorDataInput, tempVectorData);
|
||||
if (tempVectorData != null) {
|
||||
IOUtils.deleteFilesIgnoringExceptions(
|
||||
segmentWriteState.directory, tempVectorData.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param graph Write the graph in a compressed format
|
||||
* @return The non-cumulative offsets for the nodes. Should be used to create cumulative offsets.
|
||||
* @throws IOException if writing to vectorIndex fails
|
||||
*/
|
||||
private int[][] writeGraph(OnHeapHnswGraph graph) throws IOException {
|
||||
if (graph == null) return new int[0][0];
|
||||
// write vectors' neighbours on each level into the vectorIndex file
|
||||
int countOnLevel0 = graph.size();
|
||||
int[][] offsets = new int[graph.numLevels()][];
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
int[] sortedNodes = getSortedNodes(graph.getNodesOnLevel(level));
|
||||
offsets[level] = new int[sortedNodes.length];
|
||||
int nodeOffsetId = 0;
|
||||
for (int node : sortedNodes) {
|
||||
NeighborArray neighbors = graph.getNeighbors(level, node);
|
||||
int size = neighbors.size();
|
||||
// Write size in VInt as the neighbors list is typically small
|
||||
long offsetStart = vectorIndex.getFilePointer();
|
||||
vectorIndex.writeVInt(size);
|
||||
// Destructively modify; it's ok we are discarding it after this
|
||||
int[] nnodes = neighbors.node();
|
||||
Arrays.sort(nnodes, 0, size);
|
||||
// Now that we have sorted, do delta encoding to minimize the required bits to store the
|
||||
// information
|
||||
for (int i = size - 1; i > 0; --i) {
|
||||
assert nnodes[i] < countOnLevel0 : "node too large: " + nnodes[i] + ">=" + countOnLevel0;
|
||||
nnodes[i] -= nnodes[i - 1];
|
||||
}
|
||||
for (int i = 0; i < size; i++) {
|
||||
vectorIndex.writeVInt(nnodes[i]);
|
||||
}
|
||||
offsets[level][nodeOffsetId++] =
|
||||
Math.toIntExact(vectorIndex.getFilePointer() - offsetStart);
|
||||
}
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
public static int[] getSortedNodes(NodesIterator nodesOnLevel) {
|
||||
int[] sortedNodes = new int[nodesOnLevel.size()];
|
||||
for (int n = 0; nodesOnLevel.hasNext(); n++) {
|
||||
sortedNodes[n] = nodesOnLevel.nextInt();
|
||||
}
|
||||
Arrays.sort(sortedNodes);
|
||||
return sortedNodes;
|
||||
}
|
||||
|
||||
private void writeMeta(
|
||||
boolean isQuantized,
|
||||
FieldInfo field,
|
||||
int maxDoc,
|
||||
Float configuredQuantizationQuantile,
|
||||
Float lowerQuantile,
|
||||
Float upperQuantile,
|
||||
long[] quantizedVectorDataOffsetAndLen,
|
||||
long vectorDataOffset,
|
||||
long vectorDataLength,
|
||||
long vectorIndexOffset,
|
||||
long vectorIndexLength,
|
||||
DocsWithFieldSet docsWithField,
|
||||
HnswGraph graph,
|
||||
int[][] graphLevelNodeOffsets)
|
||||
throws IOException {
|
||||
meta.writeInt(field.number);
|
||||
meta.writeInt(field.getVectorEncoding().ordinal());
|
||||
meta.writeInt(field.getVectorSimilarityFunction().ordinal());
|
||||
meta.writeByte(isQuantized ? (byte) 1 : (byte) 0);
|
||||
if (isQuantized) {
|
||||
assert lowerQuantile != null
|
||||
&& upperQuantile != null
|
||||
&& quantizedVectorDataOffsetAndLen != null;
|
||||
assert quantizedVectorDataOffsetAndLen.length == 2;
|
||||
meta.writeInt(
|
||||
Float.floatToIntBits(
|
||||
configuredQuantizationQuantile != null
|
||||
? configuredQuantizationQuantile
|
||||
: calculateDefaultQuantile(field.getVectorDimension())));
|
||||
meta.writeInt(Float.floatToIntBits(lowerQuantile));
|
||||
meta.writeInt(Float.floatToIntBits(upperQuantile));
|
||||
meta.writeVLong(quantizedVectorDataOffsetAndLen[0]);
|
||||
meta.writeVLong(quantizedVectorDataOffsetAndLen[1]);
|
||||
} else {
|
||||
assert configuredQuantizationQuantile == null
|
||||
&& lowerQuantile == null
|
||||
&& upperQuantile == null
|
||||
&& quantizedVectorDataOffsetAndLen == null;
|
||||
}
|
||||
meta.writeVLong(vectorDataOffset);
|
||||
meta.writeVLong(vectorDataLength);
|
||||
meta.writeVLong(vectorIndexOffset);
|
||||
meta.writeVLong(vectorIndexLength);
|
||||
meta.writeVInt(field.getVectorDimension());
|
||||
|
||||
// write docIDs
|
||||
int count = docsWithField.cardinality();
|
||||
meta.writeInt(count);
|
||||
if (isQuantized) {
|
||||
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
||||
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, quantizedVectorData, count, maxDoc, docsWithField);
|
||||
}
|
||||
OrdToDocDISIReaderConfiguration.writeStoredMeta(
|
||||
DIRECT_MONOTONIC_BLOCK_SHIFT, meta, vectorData, count, maxDoc, docsWithField);
|
||||
|
||||
meta.writeVInt(M);
|
||||
// write graph nodes on each level
|
||||
if (graph == null) {
|
||||
meta.writeVInt(0);
|
||||
} else {
|
||||
meta.writeVInt(graph.numLevels());
|
||||
long valueCount = 0;
|
||||
for (int level = 0; level < graph.numLevels(); level++) {
|
||||
NodesIterator nodesOnLevel = graph.getNodesOnLevel(level);
|
||||
valueCount += nodesOnLevel.size();
|
||||
if (level > 0) {
|
||||
int[] nol = new int[nodesOnLevel.size()];
|
||||
int numberConsumed = nodesOnLevel.consume(nol);
|
||||
Arrays.sort(nol);
|
||||
assert numberConsumed == nodesOnLevel.size();
|
||||
meta.writeVInt(nol.length); // number of nodes on a level
|
||||
for (int i = nodesOnLevel.size() - 1; i > 0; --i) {
|
||||
nol[i] -= nol[i - 1];
|
||||
}
|
||||
for (int n : nol) {
|
||||
assert n >= 0 : "delta encoding for nodes failed; expected nodes to be sorted";
|
||||
meta.writeVInt(n);
|
||||
}
|
||||
} else {
|
||||
assert nodesOnLevel.size() == count : "Level 0 expects to have all nodes";
|
||||
}
|
||||
}
|
||||
long start = vectorIndex.getFilePointer();
|
||||
meta.writeLong(start);
|
||||
meta.writeVInt(DIRECT_MONOTONIC_BLOCK_SHIFT);
|
||||
final DirectMonotonicWriter memoryOffsetsWriter =
|
||||
DirectMonotonicWriter.getInstance(
|
||||
meta, vectorIndex, valueCount, DIRECT_MONOTONIC_BLOCK_SHIFT);
|
||||
long cumulativeOffsetSum = 0;
|
||||
for (int[] levelOffsets : graphLevelNodeOffsets) {
|
||||
for (int v : levelOffsets) {
|
||||
memoryOffsetsWriter.add(cumulativeOffsetSum);
|
||||
cumulativeOffsetSum += v;
|
||||
}
|
||||
}
|
||||
memoryOffsetsWriter.finish();
|
||||
meta.writeLong(vectorIndex.getFilePointer() - start);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the byte vector values to the output and returns a set of documents that contains
|
||||
* vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeByteVectorData(
|
||||
IndexOutput output, ByteVectorValues byteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = byteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = byteVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = byteVectorValues.vectorValue();
|
||||
assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize;
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeVectorData(
|
||||
IndexOutput output, FloatVectorValues floatVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
ByteBuffer buffer =
|
||||
ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize)
|
||||
.order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int docV = floatVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = floatVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
float[] value = floatVectorValues.vectorValue();
|
||||
buffer.asFloatBuffer().put(value);
|
||||
output.writeBytes(buffer.array(), buffer.limit());
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
IOUtils.close(meta, vectorData, vectorIndex, quantizedVectorData);
|
||||
}
|
||||
|
||||
private abstract static class FieldWriter<T> extends KnnFieldVectorsWriter<T> {
|
||||
private final FieldInfo fieldInfo;
|
||||
private final int dim;
|
||||
private final DocsWithFieldSet docsWithField;
|
||||
private final List<T> vectors;
|
||||
private final HnswGraphBuilder hnswGraphBuilder;
|
||||
private final Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter
|
||||
quantizedWriter;
|
||||
|
||||
private int lastDocID = -1;
|
||||
private int node = 0;
|
||||
|
||||
static FieldWriter<?> create(
|
||||
FieldInfo fieldInfo,
|
||||
int M,
|
||||
int beamWidth,
|
||||
InfoStream infoStream,
|
||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter writer)
|
||||
throws IOException {
|
||||
int dim = fieldInfo.getVectorDimension();
|
||||
return switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> new FieldWriter<byte[]>(fieldInfo, M, beamWidth, infoStream, writer) {
|
||||
@Override
|
||||
public byte[] copyValue(byte[] value) {
|
||||
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||
}
|
||||
};
|
||||
case FLOAT32 -> new FieldWriter<float[]>(fieldInfo, M, beamWidth, infoStream, writer) {
|
||||
@Override
|
||||
public float[] copyValue(float[] value) {
|
||||
return ArrayUtil.copyOfSubArray(value, 0, dim);
|
||||
}
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
FieldWriter(
|
||||
FieldInfo fieldInfo,
|
||||
int M,
|
||||
int beamWidth,
|
||||
InfoStream infoStream,
|
||||
Lucene99ScalarQuantizedVectorsWriter.QuantizationFieldVectorWriter quantizedWriter)
|
||||
throws IOException {
|
||||
this.fieldInfo = fieldInfo;
|
||||
this.dim = fieldInfo.getVectorDimension();
|
||||
this.docsWithField = new DocsWithFieldSet();
|
||||
this.quantizedWriter = quantizedWriter;
|
||||
vectors = new ArrayList<>();
|
||||
if (quantizedWriter != null
|
||||
&& fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) {
|
||||
throw new IllegalArgumentException(
|
||||
"Vector encoding ["
|
||||
+ VectorEncoding.FLOAT32
|
||||
+ "] required for quantized vectors; provided="
|
||||
+ fieldInfo.getVectorEncoding());
|
||||
}
|
||||
RAVectorValues<T> raVectors = new RAVectorValues<>(vectors, dim);
|
||||
RandomVectorScorerSupplier scorerSupplier =
|
||||
switch (fieldInfo.getVectorEncoding()) {
|
||||
case BYTE -> RandomVectorScorerSupplier.createBytes(
|
||||
(RandomAccessVectorValues<byte[]>) raVectors,
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
case FLOAT32 -> RandomVectorScorerSupplier.createFloats(
|
||||
(RandomAccessVectorValues<float[]>) raVectors,
|
||||
fieldInfo.getVectorSimilarityFunction());
|
||||
};
|
||||
hnswGraphBuilder =
|
||||
HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed);
|
||||
hnswGraphBuilder.setInfoStream(infoStream);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void addValue(int docID, T vectorValue) throws IOException {
|
||||
if (docID == lastDocID) {
|
||||
throw new IllegalArgumentException(
|
||||
"VectorValuesField \""
|
||||
+ fieldInfo.name
|
||||
+ "\" appears more than once in this document (only one value is allowed per field)");
|
||||
}
|
||||
assert docID > lastDocID;
|
||||
T copy = copyValue(vectorValue);
|
||||
if (quantizedWriter != null) {
|
||||
assert vectorValue instanceof float[];
|
||||
quantizedWriter.addValue((float[]) copy);
|
||||
}
|
||||
docsWithField.add(docID);
|
||||
vectors.add(copy);
|
||||
hnswGraphBuilder.addGraphNode(node);
|
||||
node++;
|
||||
lastDocID = docID;
|
||||
}
|
||||
|
||||
OnHeapHnswGraph getGraph() {
|
||||
if (vectors.size() > 0) {
|
||||
return hnswGraphBuilder.getGraph();
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
if (vectors.size() == 0) return 0;
|
||||
long quantizationSpace = quantizedWriter != null ? quantizedWriter.ramBytesUsed() : 0L;
|
||||
return docsWithField.ramBytesUsed()
|
||||
+ (long) vectors.size()
|
||||
* (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER)
|
||||
+ (long) vectors.size()
|
||||
* fieldInfo.getVectorDimension()
|
||||
* fieldInfo.getVectorEncoding().byteSize
|
||||
+ hnswGraphBuilder.getGraph().ramBytesUsed()
|
||||
+ quantizationSpace;
|
||||
}
|
||||
|
||||
Float getConfiguredQuantile() {
|
||||
return quantizedWriter == null ? null : quantizedWriter.getQuantile();
|
||||
}
|
||||
|
||||
Float getMinQuantile() {
|
||||
return quantizedWriter == null ? null : quantizedWriter.getMinQuantile();
|
||||
}
|
||||
|
||||
Float getMaxQuantile() {
|
||||
return quantizedWriter == null ? null : quantizedWriter.getMaxQuantile();
|
||||
}
|
||||
|
||||
boolean isQuantized() {
|
||||
return quantizedWriter != null;
|
||||
}
|
||||
}
|
||||
|
||||
private static class RAVectorValues<T> implements RandomAccessVectorValues<T> {
|
||||
private final List<T> vectors;
|
||||
private final int dim;
|
||||
|
||||
RAVectorValues(List<T> vectors, int dim) {
|
||||
this.vectors = vectors;
|
||||
this.dim = dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectors.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dim;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T vectorValue(int targetOrd) throws IOException {
|
||||
return vectors.get(targetOrd);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomAccessVectorValues<T> copy() throws IOException {
|
||||
return this;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
/**
|
||||
* Format supporting vector quantization, storage, and retrieval
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99ScalarQuantizedVectorsFormat {
|
||||
public static final String QUANTIZED_VECTOR_COMPONENT = "QVEC";
|
||||
|
||||
static final String NAME = "Lucene99ScalarQuantizedVectorsFormat";
|
||||
|
||||
static final int VERSION_START = 0;
|
||||
static final int VERSION_CURRENT = VERSION_START;
|
||||
static final String QUANTIZED_VECTOR_DATA_CODEC_NAME = "Lucene99ScalarQuantizedVectorsData";
|
||||
static final String QUANTIZED_VECTOR_DATA_EXTENSION = "veq";
|
||||
|
||||
/** The minimum quantile */
|
||||
private static final float MINIMUM_QUANTILE = 0.9f;
|
||||
|
||||
/** The maximum quantile */
|
||||
private static final float MAXIMUM_QUANTILE = 1f;
|
||||
|
||||
/**
|
||||
* Controls the quantile used to scalar quantize the vectors the default quantile is calculated as
|
||||
* `1-1/(vector_dimensions + 1)`
|
||||
*/
|
||||
final Float quantile;
|
||||
|
||||
/** Constructs a format using default graph construction parameters */
|
||||
public Lucene99ScalarQuantizedVectorsFormat() {
|
||||
this(null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Constructs a format using the given graph construction parameters.
|
||||
*
|
||||
* @param quantile the quantile for scalar quantizing the vectors, when `null` it is calculated
|
||||
* based on the vector field dimensions.
|
||||
*/
|
||||
public Lucene99ScalarQuantizedVectorsFormat(Float quantile) {
|
||||
if (quantile != null && (quantile < MINIMUM_QUANTILE || quantile > MAXIMUM_QUANTILE)) {
|
||||
throw new IllegalArgumentException(
|
||||
"quantile must be between "
|
||||
+ MINIMUM_QUANTILE
|
||||
+ " and "
|
||||
+ MAXIMUM_QUANTILE
|
||||
+ "; quantile="
|
||||
+ quantile);
|
||||
}
|
||||
this.quantile = quantile;
|
||||
}
|
||||
|
||||
static float calculateDefaultQuantile(int vectorDimension) {
|
||||
return Math.max(MINIMUM_QUANTILE, 1f - (1f / (vectorDimension + 1)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return NAME + "(name=" + NAME + ", quantile=" + quantile + ")";
|
||||
}
|
||||
}
|
|
@ -0,0 +1,88 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
|
||||
/**
|
||||
* Reads Scalar Quantized vectors from the index segments along with index data structures.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99ScalarQuantizedVectorsReader {
|
||||
|
||||
private final IndexInput quantizedVectorData;
|
||||
|
||||
Lucene99ScalarQuantizedVectorsReader(IndexInput quantizedVectorData) {
|
||||
this.quantizedVectorData = quantizedVectorData;
|
||||
}
|
||||
|
||||
static void validateFieldEntry(
|
||||
FieldInfo info, int fieldDimension, int size, long quantizedVectorDataLength) {
|
||||
int dimension = info.getVectorDimension();
|
||||
if (dimension != fieldDimension) {
|
||||
throw new IllegalStateException(
|
||||
"Inconsistent vector dimension for field=\""
|
||||
+ info.name
|
||||
+ "\"; "
|
||||
+ dimension
|
||||
+ " != "
|
||||
+ fieldDimension);
|
||||
}
|
||||
|
||||
// int8 quantized and calculated stored offset.
|
||||
long quantizedVectorBytes = dimension + Float.BYTES;
|
||||
long numQuantizedVectorBytes = Math.multiplyExact(quantizedVectorBytes, size);
|
||||
if (numQuantizedVectorBytes != quantizedVectorDataLength) {
|
||||
throw new IllegalStateException(
|
||||
"Quantized vector data length "
|
||||
+ quantizedVectorDataLength
|
||||
+ " not matching size="
|
||||
+ size
|
||||
+ " * (dim="
|
||||
+ dimension
|
||||
+ " + 4)"
|
||||
+ " = "
|
||||
+ numQuantizedVectorBytes);
|
||||
}
|
||||
}
|
||||
|
||||
public void checkIntegrity() throws IOException {
|
||||
CodecUtil.checksumEntireFile(quantizedVectorData);
|
||||
}
|
||||
|
||||
OffHeapQuantizedByteVectorValues getQuantizedVectorValues(
|
||||
OrdToDocDISIReaderConfiguration configuration,
|
||||
int dimension,
|
||||
int size,
|
||||
long quantizedVectorDataOffset,
|
||||
long quantizedVectorDataLength)
|
||||
throws IOException {
|
||||
return OffHeapQuantizedByteVectorValues.load(
|
||||
configuration,
|
||||
dimension,
|
||||
size,
|
||||
quantizedVectorDataOffset,
|
||||
quantizedVectorDataLength,
|
||||
quantizedVectorData);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,824 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT;
|
||||
import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.CodecUtil;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.index.DocIDMerger;
|
||||
import org.apache.lucene.index.DocsWithFieldSet;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.MergeState;
|
||||
import org.apache.lucene.index.SegmentWriteState;
|
||||
import org.apache.lucene.index.Sorter;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.store.IndexOutput;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.IOUtils;
|
||||
import org.apache.lucene.util.InfoStream;
|
||||
import org.apache.lucene.util.RamUsageEstimator;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.CloseableRandomVectorScorerSupplier;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
/**
|
||||
* Writes quantized vector values and metadata to index segments.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||
|
||||
private static final long BASE_RAM_BYTES_USED =
|
||||
shallowSizeOfInstance(Lucene99ScalarQuantizedVectorsWriter.class);
|
||||
|
||||
// Used for determining when merged quantiles shifted too far from individual segment quantiles.
|
||||
// When merging quantiles from various segments, we need to ensure that the new quantiles
|
||||
// are not exceptionally different from an individual segments quantiles.
|
||||
// This would imply that the quantization buckets would shift too much
|
||||
// for floating point values and justify recalculating the quantiles. This helps preserve
|
||||
// accuracy of the calculated quantiles, even in adversarial cases such as vector clustering.
|
||||
// This number was determined via empirical testing
|
||||
private static final float QUANTILE_RECOMPUTE_LIMIT = 32;
|
||||
// Used for determining if a new quantization state requires a re-quantization
|
||||
// for a given segment.
|
||||
// This ensures that in expectation 4/5 of the vector would be unchanged by requantization.
|
||||
// Furthermore, only those values where the value is within 1/5 of the centre of a quantization
|
||||
// bin will be changed. In these cases the error introduced by snapping one way or another
|
||||
// is small compared to the error introduced by quantization in the first place. Furthermore,
|
||||
// empirical testing showed that the relative error by not requantizing is small (compared to
|
||||
// the quantization error) and the condition is sensitive enough to detect all adversarial cases,
|
||||
// such as merging clustered data.
|
||||
private static final float REQUANTIZATION_LIMIT = 0.2f;
|
||||
private final IndexOutput quantizedVectorData;
|
||||
private final Float quantile;
|
||||
private boolean finished;
|
||||
|
||||
Lucene99ScalarQuantizedVectorsWriter(IndexOutput quantizedVectorData, Float quantile) {
|
||||
this.quantile = quantile;
|
||||
this.quantizedVectorData = quantizedVectorData;
|
||||
}
|
||||
|
||||
QuantizationFieldVectorWriter addField(FieldInfo fieldInfo, InfoStream infoStream) {
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
throw new IllegalArgumentException(
|
||||
"Only float32 vector fields are supported for quantization");
|
||||
}
|
||||
float quantile =
|
||||
this.quantile == null
|
||||
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||
: this.quantile;
|
||||
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
||||
infoStream.message(
|
||||
QUANTIZED_VECTOR_COMPONENT,
|
||||
"quantizing field="
|
||||
+ fieldInfo.name
|
||||
+ " dimension="
|
||||
+ fieldInfo.getVectorDimension()
|
||||
+ " quantile="
|
||||
+ quantile);
|
||||
}
|
||||
return QuantizationFieldVectorWriter.create(fieldInfo, quantile, infoStream);
|
||||
}
|
||||
|
||||
long[] flush(
|
||||
Sorter.DocMap sortMap, QuantizationFieldVectorWriter field, DocsWithFieldSet docsWithField)
|
||||
throws IOException {
|
||||
field.finish();
|
||||
return sortMap == null ? writeField(field) : writeSortingField(field, sortMap, docsWithField);
|
||||
}
|
||||
|
||||
void finish() throws IOException {
|
||||
if (finished) {
|
||||
throw new IllegalStateException("already finished");
|
||||
}
|
||||
finished = true;
|
||||
if (quantizedVectorData != null) {
|
||||
CodecUtil.writeFooter(quantizedVectorData);
|
||||
}
|
||||
}
|
||||
|
||||
private long[] writeField(QuantizationFieldVectorWriter fieldData) throws IOException {
|
||||
long quantizedVectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||
writeQuantizedVectors(fieldData);
|
||||
long quantizedVectorDataLength =
|
||||
quantizedVectorData.getFilePointer() - quantizedVectorDataOffset;
|
||||
return new long[] {quantizedVectorDataOffset, quantizedVectorDataLength};
|
||||
}
|
||||
|
||||
private void writeQuantizedVectors(QuantizationFieldVectorWriter fieldData) throws IOException {
|
||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||
byte[] vector = new byte[fieldData.dim];
|
||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (float[] v : fieldData.floatVectors) {
|
||||
float offsetCorrection =
|
||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
||||
quantizedVectorData.writeBytes(vector, vector.length);
|
||||
offsetBuffer.putFloat(offsetCorrection);
|
||||
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
||||
offsetBuffer.rewind();
|
||||
}
|
||||
}
|
||||
|
||||
private long[] writeSortingField(
|
||||
QuantizationFieldVectorWriter fieldData,
|
||||
Sorter.DocMap sortMap,
|
||||
DocsWithFieldSet docsWithField)
|
||||
throws IOException {
|
||||
final int[] docIdOffsets = new int[sortMap.size()];
|
||||
int offset = 1; // 0 means no vector for this (field, document)
|
||||
DocIdSetIterator iterator = docsWithField.iterator();
|
||||
for (int docID = iterator.nextDoc();
|
||||
docID != DocIdSetIterator.NO_MORE_DOCS;
|
||||
docID = iterator.nextDoc()) {
|
||||
int newDocID = sortMap.oldToNew(docID);
|
||||
docIdOffsets[newDocID] = offset++;
|
||||
}
|
||||
DocsWithFieldSet newDocsWithField = new DocsWithFieldSet();
|
||||
final int[] ordMap = new int[offset - 1]; // new ord to old ord
|
||||
final int[] oldOrdMap = new int[offset - 1]; // old ord to new ord
|
||||
int ord = 0;
|
||||
int doc = 0;
|
||||
for (int docIdOffset : docIdOffsets) {
|
||||
if (docIdOffset != 0) {
|
||||
ordMap[ord] = docIdOffset - 1;
|
||||
oldOrdMap[docIdOffset - 1] = ord;
|
||||
newDocsWithField.add(doc);
|
||||
ord++;
|
||||
}
|
||||
doc++;
|
||||
}
|
||||
|
||||
// write vector values
|
||||
long vectorDataOffset = quantizedVectorData.alignFilePointer(Float.BYTES);
|
||||
writeSortedQuantizedVectors(fieldData, ordMap);
|
||||
long quantizedVectorLength = quantizedVectorData.getFilePointer() - vectorDataOffset;
|
||||
|
||||
return new long[] {vectorDataOffset, quantizedVectorLength};
|
||||
}
|
||||
|
||||
void writeSortedQuantizedVectors(QuantizationFieldVectorWriter fieldData, int[] ordMap)
|
||||
throws IOException {
|
||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||
byte[] vector = new byte[fieldData.dim];
|
||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||
for (int ordinal : ordMap) {
|
||||
float[] v = fieldData.floatVectors.get(ordinal);
|
||||
float offsetCorrection =
|
||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
||||
quantizedVectorData.writeBytes(vector, vector.length);
|
||||
offsetBuffer.putFloat(offsetCorrection);
|
||||
quantizedVectorData.writeBytes(offsetBuffer.array(), offsetBuffer.array().length);
|
||||
offsetBuffer.rewind();
|
||||
}
|
||||
}
|
||||
|
||||
ScalarQuantizer mergeQuantiles(FieldInfo fieldInfo, MergeState mergeState) throws IOException {
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
float quantile =
|
||||
this.quantile == null
|
||||
? calculateDefaultQuantile(fieldInfo.getVectorDimension())
|
||||
: this.quantile;
|
||||
return mergeAndRecalculateQuantiles(mergeState, fieldInfo, quantile);
|
||||
}
|
||||
|
||||
ScalarQuantizedCloseableRandomVectorScorerSupplier mergeOneField(
|
||||
SegmentWriteState segmentWriteState,
|
||||
FieldInfo fieldInfo,
|
||||
MergeState mergeState,
|
||||
ScalarQuantizer mergedQuantizationState)
|
||||
throws IOException {
|
||||
if (fieldInfo.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
IndexOutput tempQuantizedVectorData =
|
||||
segmentWriteState.directory.createTempOutput(
|
||||
quantizedVectorData.getName(), "temp", segmentWriteState.context);
|
||||
IndexInput quantizationDataInput = null;
|
||||
boolean success = false;
|
||||
try {
|
||||
MergedQuantizedVectorValues byteVectorValues =
|
||||
MergedQuantizedVectorValues.mergeQuantizedByteVectorValues(
|
||||
fieldInfo, mergeState, mergedQuantizationState);
|
||||
writeQuantizedVectorData(tempQuantizedVectorData, byteVectorValues);
|
||||
CodecUtil.writeFooter(tempQuantizedVectorData);
|
||||
IOUtils.close(tempQuantizedVectorData);
|
||||
quantizationDataInput =
|
||||
segmentWriteState.directory.openInput(
|
||||
tempQuantizedVectorData.getName(), segmentWriteState.context);
|
||||
quantizedVectorData.copyBytes(
|
||||
quantizationDataInput, quantizationDataInput.length() - CodecUtil.footerLength());
|
||||
CodecUtil.retrieveChecksum(quantizationDataInput);
|
||||
success = true;
|
||||
final IndexInput finalQuantizationDataInput = quantizationDataInput;
|
||||
return new ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
||||
() -> {
|
||||
IOUtils.close(finalQuantizationDataInput);
|
||||
segmentWriteState.directory.deleteFile(tempQuantizedVectorData.getName());
|
||||
},
|
||||
new ScalarQuantizedRandomVectorScorerSupplier(
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
mergedQuantizationState,
|
||||
new OffHeapQuantizedByteVectorValues.DenseOffHeapVectorValues(
|
||||
fieldInfo.getVectorDimension(), byteVectorValues.size(), quantizationDataInput)));
|
||||
} finally {
|
||||
if (success == false) {
|
||||
IOUtils.closeWhileHandlingException(quantizationDataInput);
|
||||
IOUtils.deleteFilesIgnoringExceptions(
|
||||
segmentWriteState.directory, tempQuantizedVectorData.getName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static ScalarQuantizer mergeQuantiles(
|
||||
List<ScalarQuantizer> quantizationStates, List<Integer> segmentSizes, float quantile) {
|
||||
assert quantizationStates.size() == segmentSizes.size();
|
||||
if (quantizationStates.isEmpty()) {
|
||||
return null;
|
||||
}
|
||||
float lowerQuantile = 0f;
|
||||
float upperQuantile = 0f;
|
||||
int totalCount = 0;
|
||||
for (int i = 0; i < quantizationStates.size(); i++) {
|
||||
if (quantizationStates.get(i) == null) {
|
||||
return null;
|
||||
}
|
||||
lowerQuantile += quantizationStates.get(i).getLowerQuantile() * segmentSizes.get(i);
|
||||
upperQuantile += quantizationStates.get(i).getUpperQuantile() * segmentSizes.get(i);
|
||||
totalCount += segmentSizes.get(i);
|
||||
}
|
||||
lowerQuantile /= totalCount;
|
||||
upperQuantile /= totalCount;
|
||||
return new ScalarQuantizer(lowerQuantile, upperQuantile, quantile);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the quantiles of the merged state are too far from the quantiles of the
|
||||
* individual states.
|
||||
*
|
||||
* @param mergedQuantizationState The merged quantization state
|
||||
* @param quantizationStates The quantization states of the individual segments
|
||||
* @return true if the quantiles should be recomputed
|
||||
*/
|
||||
static boolean shouldRecomputeQuantiles(
|
||||
ScalarQuantizer mergedQuantizationState, List<ScalarQuantizer> quantizationStates) {
|
||||
// calculate the limit for the quantiles to be considered too far apart
|
||||
// We utilize upper & lower here to determine if the new upper and merged upper would
|
||||
// drastically
|
||||
// change the quantization buckets for floats
|
||||
// This is a fairly conservative check.
|
||||
float limit =
|
||||
(mergedQuantizationState.getUpperQuantile() - mergedQuantizationState.getLowerQuantile())
|
||||
/ QUANTILE_RECOMPUTE_LIMIT;
|
||||
for (ScalarQuantizer quantizationState : quantizationStates) {
|
||||
if (Math.abs(
|
||||
quantizationState.getUpperQuantile() - mergedQuantizationState.getUpperQuantile())
|
||||
> limit) {
|
||||
return true;
|
||||
}
|
||||
if (Math.abs(
|
||||
quantizationState.getLowerQuantile() - mergedQuantizationState.getLowerQuantile())
|
||||
> limit) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static QuantizedVectorsReader getQuantizedKnnVectorsReader(
|
||||
KnnVectorsReader vectorsReader, String fieldName) {
|
||||
if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) {
|
||||
vectorsReader = candidateReader.getFieldReader(fieldName);
|
||||
}
|
||||
if (vectorsReader instanceof QuantizedVectorsReader reader) {
|
||||
return reader;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static ScalarQuantizer getQuantizedState(
|
||||
KnnVectorsReader vectorsReader, String fieldName) {
|
||||
QuantizedVectorsReader reader = getQuantizedKnnVectorsReader(vectorsReader, fieldName);
|
||||
if (reader != null) {
|
||||
return reader.getQuantizationState(fieldName);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
static ScalarQuantizer mergeAndRecalculateQuantiles(
|
||||
MergeState mergeState, FieldInfo fieldInfo, float quantile) throws IOException {
|
||||
List<ScalarQuantizer> quantizationStates = new ArrayList<>(mergeState.liveDocs.length);
|
||||
List<Integer> segmentSizes = new ArrayList<>(mergeState.liveDocs.length);
|
||||
for (int i = 0; i < mergeState.liveDocs.length; i++) {
|
||||
FloatVectorValues fvv;
|
||||
if (mergeState.knnVectorsReaders[i] != null
|
||||
&& (fvv = mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name)) != null
|
||||
&& fvv.size() > 0) {
|
||||
ScalarQuantizer quantizationState =
|
||||
getQuantizedState(mergeState.knnVectorsReaders[i], fieldInfo.name);
|
||||
// If we have quantization state, we can utilize that to make merging cheaper
|
||||
quantizationStates.add(quantizationState);
|
||||
segmentSizes.add(fvv.size());
|
||||
}
|
||||
}
|
||||
ScalarQuantizer mergedQuantiles = mergeQuantiles(quantizationStates, segmentSizes, quantile);
|
||||
// Segments no providing quantization state indicates that their quantiles were never
|
||||
// calculated.
|
||||
// To be safe, we should always recalculate given a sample set over all the float vectors in the
|
||||
// merged
|
||||
// segment view
|
||||
if (mergedQuantiles == null || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) {
|
||||
FloatVectorValues vectorValues =
|
||||
KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState);
|
||||
mergedQuantiles = ScalarQuantizer.fromVectors(vectorValues, quantile);
|
||||
}
|
||||
return mergedQuantiles;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true if the quantiles of the new quantization state are too far from the quantiles of
|
||||
* the existing quantization state. This would imply that floating point values would slightly
|
||||
* shift quantization buckets.
|
||||
*
|
||||
* @param existingQuantiles The existing quantiles for a segment
|
||||
* @param newQuantiles The new quantiles for a segment, could be merged, or fully re-calculated
|
||||
* @return true if the floating point values should be requantized
|
||||
*/
|
||||
static boolean shouldRequantize(ScalarQuantizer existingQuantiles, ScalarQuantizer newQuantiles) {
|
||||
float tol =
|
||||
REQUANTIZATION_LIMIT
|
||||
* (newQuantiles.getUpperQuantile() - newQuantiles.getLowerQuantile())
|
||||
/ 128f;
|
||||
if (Math.abs(existingQuantiles.getUpperQuantile() - newQuantiles.getUpperQuantile()) > tol) {
|
||||
return true;
|
||||
}
|
||||
return Math.abs(existingQuantiles.getLowerQuantile() - newQuantiles.getLowerQuantile()) > tol;
|
||||
}
|
||||
|
||||
/**
|
||||
* Writes the vector values to the output and returns a set of documents that contains vectors.
|
||||
*/
|
||||
private static DocsWithFieldSet writeQuantizedVectorData(
|
||||
IndexOutput output, QuantizedByteVectorValues quantizedByteVectorValues) throws IOException {
|
||||
DocsWithFieldSet docsWithField = new DocsWithFieldSet();
|
||||
for (int docV = quantizedByteVectorValues.nextDoc();
|
||||
docV != NO_MORE_DOCS;
|
||||
docV = quantizedByteVectorValues.nextDoc()) {
|
||||
// write vector
|
||||
byte[] binaryValue = quantizedByteVectorValues.vectorValue();
|
||||
assert binaryValue.length == quantizedByteVectorValues.dimension()
|
||||
: "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length;
|
||||
output.writeBytes(binaryValue, binaryValue.length);
|
||||
output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant()));
|
||||
docsWithField.add(docV);
|
||||
}
|
||||
return docsWithField;
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
return BASE_RAM_BYTES_USED;
|
||||
}
|
||||
|
||||
static class QuantizationFieldVectorWriter implements Accountable {
|
||||
private static final long SHALLOW_SIZE =
|
||||
shallowSizeOfInstance(QuantizationFieldVectorWriter.class);
|
||||
private final int dim;
|
||||
private final List<float[]> floatVectors;
|
||||
private final boolean normalize;
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
private final float quantile;
|
||||
private final InfoStream infoStream;
|
||||
private float minQuantile = Float.POSITIVE_INFINITY;
|
||||
private float maxQuantile = Float.NEGATIVE_INFINITY;
|
||||
private boolean finished;
|
||||
|
||||
static QuantizationFieldVectorWriter create(
|
||||
FieldInfo fieldInfo, float quantile, InfoStream infoStream) {
|
||||
return new QuantizationFieldVectorWriter(
|
||||
fieldInfo.getVectorDimension(),
|
||||
quantile,
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
infoStream);
|
||||
}
|
||||
|
||||
QuantizationFieldVectorWriter(
|
||||
int dim,
|
||||
float quantile,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
InfoStream infoStream) {
|
||||
this.dim = dim;
|
||||
this.quantile = quantile;
|
||||
this.normalize = vectorSimilarityFunction == VectorSimilarityFunction.COSINE;
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
this.floatVectors = new ArrayList<>();
|
||||
this.infoStream = infoStream;
|
||||
}
|
||||
|
||||
void finish() throws IOException {
|
||||
if (finished) {
|
||||
return;
|
||||
}
|
||||
if (floatVectors.size() == 0) {
|
||||
finished = true;
|
||||
return;
|
||||
}
|
||||
ScalarQuantizer quantizer =
|
||||
ScalarQuantizer.fromVectors(new FloatVectorWrapper(floatVectors, normalize), quantile);
|
||||
minQuantile = quantizer.getLowerQuantile();
|
||||
maxQuantile = quantizer.getUpperQuantile();
|
||||
if (infoStream.isEnabled(QUANTIZED_VECTOR_COMPONENT)) {
|
||||
infoStream.message(
|
||||
QUANTIZED_VECTOR_COMPONENT,
|
||||
"quantized field="
|
||||
+ " dimension="
|
||||
+ dim
|
||||
+ " quantile="
|
||||
+ quantile
|
||||
+ " minQuantile="
|
||||
+ minQuantile
|
||||
+ " maxQuantile="
|
||||
+ maxQuantile);
|
||||
}
|
||||
finished = true;
|
||||
}
|
||||
|
||||
public void addValue(float[] vectorValue) throws IOException {
|
||||
floatVectors.add(vectorValue);
|
||||
}
|
||||
|
||||
float getMinQuantile() {
|
||||
assert finished;
|
||||
return minQuantile;
|
||||
}
|
||||
|
||||
float getMaxQuantile() {
|
||||
assert finished;
|
||||
return maxQuantile;
|
||||
}
|
||||
|
||||
float getQuantile() {
|
||||
return quantile;
|
||||
}
|
||||
|
||||
ScalarQuantizer createQuantizer() {
|
||||
assert finished;
|
||||
return new ScalarQuantizer(minQuantile, maxQuantile, quantile);
|
||||
}
|
||||
|
||||
@Override
|
||||
public long ramBytesUsed() {
|
||||
if (floatVectors.size() == 0) return SHALLOW_SIZE;
|
||||
return SHALLOW_SIZE + (long) floatVectors.size() * RamUsageEstimator.NUM_BYTES_OBJECT_REF;
|
||||
}
|
||||
}
|
||||
|
||||
static class FloatVectorWrapper extends FloatVectorValues {
|
||||
private final List<float[]> vectorList;
|
||||
private final float[] copy;
|
||||
private final boolean normalize;
|
||||
protected int curDoc = -1;
|
||||
|
||||
FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
|
||||
this.vectorList = vectorList;
|
||||
this.copy = new float[vectorList.get(0).length];
|
||||
this.normalize = normalize;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return vectorList.get(0).length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return vectorList.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= vectorList.size()) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
if (normalize) {
|
||||
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
|
||||
VectorUtil.l2normalize(copy);
|
||||
return copy;
|
||||
}
|
||||
return vectorList.get(curDoc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curDoc >= vectorList.size()) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return curDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
curDoc++;
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = target;
|
||||
return docID();
|
||||
}
|
||||
}
|
||||
|
||||
private static class QuantizedByteVectorValueSub extends DocIDMerger.Sub {
|
||||
private final QuantizedByteVectorValues values;
|
||||
|
||||
QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) {
|
||||
super(docMap);
|
||||
this.values = values;
|
||||
assert values.docID() == -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return values.nextDoc();
|
||||
}
|
||||
}
|
||||
|
||||
/** Returns a merged view over all the segment's {@link QuantizedByteVectorValues}. */
|
||||
static class MergedQuantizedVectorValues extends QuantizedByteVectorValues {
|
||||
public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues(
|
||||
FieldInfo fieldInfo, MergeState mergeState, ScalarQuantizer scalarQuantizer)
|
||||
throws IOException {
|
||||
assert fieldInfo != null && fieldInfo.hasVectorValues();
|
||||
|
||||
List<QuantizedByteVectorValueSub> subs = new ArrayList<>();
|
||||
for (int i = 0; i < mergeState.knnVectorsReaders.length; i++) {
|
||||
if (mergeState.knnVectorsReaders[i] != null
|
||||
&& mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name) != null) {
|
||||
QuantizedVectorsReader reader =
|
||||
getQuantizedKnnVectorsReader(mergeState.knnVectorsReaders[i], fieldInfo.name);
|
||||
assert scalarQuantizer != null;
|
||||
final QuantizedByteVectorValueSub sub;
|
||||
// Either our quantization parameters are way different than the merged ones
|
||||
// Or we have never been quantized.
|
||||
if (reader == null
|
||||
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
|
||||
sub =
|
||||
new QuantizedByteVectorValueSub(
|
||||
mergeState.docMaps[i],
|
||||
new QuantizedFloatVectorValues(
|
||||
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
scalarQuantizer));
|
||||
} else {
|
||||
sub =
|
||||
new QuantizedByteVectorValueSub(
|
||||
mergeState.docMaps[i],
|
||||
new OffsetCorrectedQuantizedByteVectorValues(
|
||||
reader.getQuantizedVectorValues(fieldInfo.name),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
scalarQuantizer,
|
||||
reader.getQuantizationState(fieldInfo.name)));
|
||||
}
|
||||
subs.add(sub);
|
||||
}
|
||||
}
|
||||
return new MergedQuantizedVectorValues(subs, mergeState);
|
||||
}
|
||||
|
||||
private final List<QuantizedByteVectorValueSub> subs;
|
||||
private final DocIDMerger<QuantizedByteVectorValueSub> docIdMerger;
|
||||
private final int size;
|
||||
|
||||
private int docId;
|
||||
private QuantizedByteVectorValueSub current;
|
||||
|
||||
private MergedQuantizedVectorValues(
|
||||
List<QuantizedByteVectorValueSub> subs, MergeState mergeState) throws IOException {
|
||||
this.subs = subs;
|
||||
docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort);
|
||||
int totalSize = 0;
|
||||
for (QuantizedByteVectorValueSub sub : subs) {
|
||||
totalSize += sub.values.size();
|
||||
}
|
||||
size = totalSize;
|
||||
docId = -1;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return current.values.vectorValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
current = docIdMerger.next();
|
||||
if (current == null) {
|
||||
docId = NO_MORE_DOCS;
|
||||
} else {
|
||||
docId = current.mappedDocID;
|
||||
}
|
||||
return docId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return subs.get(0).values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
float getScoreCorrectionConstant() throws IOException {
|
||||
return current.values.getScoreCorrectionConstant();
|
||||
}
|
||||
}
|
||||
|
||||
private static class QuantizedFloatVectorValues extends QuantizedByteVectorValues {
|
||||
private final FloatVectorValues values;
|
||||
private final ScalarQuantizer quantizer;
|
||||
private final byte[] quantizedVector;
|
||||
private float offsetValue = 0f;
|
||||
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
||||
public QuantizedFloatVectorValues(
|
||||
FloatVectorValues values,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
ScalarQuantizer quantizer) {
|
||||
this.values = values;
|
||||
this.quantizer = quantizer;
|
||||
this.quantizedVector = new byte[values.dimension()];
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
}
|
||||
|
||||
@Override
|
||||
float getScoreCorrectionConstant() {
|
||||
return offsetValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return values.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return quantizedVector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return values.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
int doc = values.nextDoc();
|
||||
if (doc != NO_MORE_DOCS) {
|
||||
offsetValue =
|
||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
int doc = values.advance(target);
|
||||
if (doc != NO_MORE_DOCS) {
|
||||
offsetValue =
|
||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||
}
|
||||
return doc;
|
||||
}
|
||||
}
|
||||
|
||||
static final class ScalarQuantizedCloseableRandomVectorScorerSupplier
|
||||
implements CloseableRandomVectorScorerSupplier {
|
||||
|
||||
private final ScalarQuantizedRandomVectorScorerSupplier supplier;
|
||||
private final Closeable onClose;
|
||||
|
||||
ScalarQuantizedCloseableRandomVectorScorerSupplier(
|
||||
Closeable onClose, ScalarQuantizedRandomVectorScorerSupplier supplier) {
|
||||
this.onClose = onClose;
|
||||
this.supplier = supplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||
return supplier.scorer(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void close() throws IOException {
|
||||
onClose.close();
|
||||
}
|
||||
}
|
||||
|
||||
private static final class OffsetCorrectedQuantizedByteVectorValues
|
||||
extends QuantizedByteVectorValues {
|
||||
|
||||
private final QuantizedByteVectorValues in;
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
private final ScalarQuantizer scalarQuantizer, oldScalarQuantizer;
|
||||
|
||||
private OffsetCorrectedQuantizedByteVectorValues(
|
||||
QuantizedByteVectorValues in,
|
||||
VectorSimilarityFunction vectorSimilarityFunction,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
ScalarQuantizer oldScalarQuantizer) {
|
||||
this.in = in;
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
this.scalarQuantizer = scalarQuantizer;
|
||||
this.oldScalarQuantizer = oldScalarQuantizer;
|
||||
}
|
||||
|
||||
@Override
|
||||
float getScoreCorrectionConstant() throws IOException {
|
||||
return scalarQuantizer.recalculateCorrectiveOffset(
|
||||
in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return in.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return in.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return in.vectorValue();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return in.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return in.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return in.advance(target);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,275 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.nio.ByteBuffer;
|
||||
import org.apache.lucene.codecs.lucene90.IndexedDISI;
|
||||
import org.apache.lucene.codecs.lucene95.OrdToDocDISIReaderConfiguration;
|
||||
import org.apache.lucene.store.IndexInput;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.packed.DirectMonotonicReader;
|
||||
|
||||
/**
|
||||
* Read the quantized vector values and their score correction values from the index input. This
|
||||
* supports both iterated and random access.
|
||||
*/
|
||||
abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues
|
||||
implements RandomAccessQuantizedByteVectorValues {
|
||||
|
||||
protected final int dimension;
|
||||
protected final int size;
|
||||
protected final IndexInput slice;
|
||||
protected final byte[] binaryValue;
|
||||
protected final ByteBuffer byteBuffer;
|
||||
protected final int byteSize;
|
||||
protected int lastOrd = -1;
|
||||
protected final float[] scoreCorrectionConstant = new float[1];
|
||||
|
||||
OffHeapQuantizedByteVectorValues(int dimension, int size, IndexInput slice) {
|
||||
this.dimension = dimension;
|
||||
this.size = size;
|
||||
this.slice = slice;
|
||||
this.byteSize = dimension + Float.BYTES;
|
||||
byteBuffer = ByteBuffer.allocate(dimension);
|
||||
binaryValue = byteBuffer.array();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return dimension;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return size;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int targetOrd) throws IOException {
|
||||
if (lastOrd == targetOrd) {
|
||||
return binaryValue;
|
||||
}
|
||||
slice.seek((long) targetOrd * byteSize);
|
||||
slice.readBytes(byteBuffer.array(), byteBuffer.arrayOffset(), dimension);
|
||||
slice.readFloats(scoreCorrectionConstant, 0, 1);
|
||||
lastOrd = targetOrd;
|
||||
return binaryValue;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getScoreCorrectionConstant() {
|
||||
return scoreCorrectionConstant[0];
|
||||
}
|
||||
|
||||
static OffHeapQuantizedByteVectorValues load(
|
||||
OrdToDocDISIReaderConfiguration configuration,
|
||||
int dimension,
|
||||
int size,
|
||||
long quantizedVectorDataOffset,
|
||||
long quantizedVectorDataLength,
|
||||
IndexInput vectorData)
|
||||
throws IOException {
|
||||
if (configuration.isEmpty()) {
|
||||
return new EmptyOffHeapVectorValues(dimension);
|
||||
}
|
||||
IndexInput bytesSlice =
|
||||
vectorData.slice(
|
||||
"quantized-vector-data", quantizedVectorDataOffset, quantizedVectorDataLength);
|
||||
if (configuration.isDense()) {
|
||||
return new DenseOffHeapVectorValues(dimension, size, bytesSlice);
|
||||
} else {
|
||||
return new SparseOffHeapVectorValues(configuration, dimension, size, vectorData, bytesSlice);
|
||||
}
|
||||
}
|
||||
|
||||
abstract Bits getAcceptOrds(Bits acceptDocs);
|
||||
|
||||
static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
public DenseOffHeapVectorValues(int dimension, int size, IndexInput slice) {
|
||||
super(dimension, size, slice);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(doc);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
if (target >= size) {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
return doc = target;
|
||||
}
|
||||
|
||||
@Override
|
||||
public DenseOffHeapVectorValues copy() throws IOException {
|
||||
return new DenseOffHeapVectorValues(dimension, size, slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return acceptDocs;
|
||||
}
|
||||
}
|
||||
|
||||
private static class SparseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
private final DirectMonotonicReader ordToDoc;
|
||||
private final IndexedDISI disi;
|
||||
// dataIn was used to init a new IndexedDIS for #randomAccess()
|
||||
private final IndexInput dataIn;
|
||||
private final OrdToDocDISIReaderConfiguration configuration;
|
||||
|
||||
public SparseOffHeapVectorValues(
|
||||
OrdToDocDISIReaderConfiguration configuration,
|
||||
int dimension,
|
||||
int size,
|
||||
IndexInput dataIn,
|
||||
IndexInput slice)
|
||||
throws IOException {
|
||||
super(dimension, size, slice);
|
||||
this.configuration = configuration;
|
||||
this.dataIn = dataIn;
|
||||
this.ordToDoc = configuration.getDirectMonotonicReader(dataIn);
|
||||
this.disi = configuration.getIndexedDISI(dataIn);
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
return vectorValue(disi.index());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return disi.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return disi.nextDoc();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
assert docID() < target;
|
||||
return disi.advance(target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public SparseOffHeapVectorValues copy() throws IOException {
|
||||
return new SparseOffHeapVectorValues(configuration, dimension, size, dataIn, slice.clone());
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
return (int) ordToDoc.get(ord);
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
if (acceptDocs == null) {
|
||||
return null;
|
||||
}
|
||||
return new Bits() {
|
||||
@Override
|
||||
public boolean get(int index) {
|
||||
return acceptDocs.get(ordToDoc(index));
|
||||
}
|
||||
|
||||
@Override
|
||||
public int length() {
|
||||
return size;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
private static class EmptyOffHeapVectorValues extends OffHeapQuantizedByteVectorValues {
|
||||
|
||||
public EmptyOffHeapVectorValues(int dimension) {
|
||||
super(dimension, 0, null);
|
||||
}
|
||||
|
||||
private int doc = -1;
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return super.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return doc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
return advance(doc + 1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
return doc = NO_MORE_DOCS;
|
||||
}
|
||||
|
||||
@Override
|
||||
public EmptyOffHeapVectorValues copy() throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public byte[] vectorValue(int targetOrd) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int ordToDoc(int ord) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
Bits getAcceptOrds(Bits acceptDocs) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.ByteVectorValues;
|
||||
|
||||
/**
|
||||
* A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for
|
||||
* Scalar quantization scores.
|
||||
*/
|
||||
abstract class QuantizedByteVectorValues extends ByteVectorValues {
|
||||
abstract float getScoreCorrectionConstant() throws IOException;
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.Closeable;
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.Accountable;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
|
||||
/** Quantized vector reader */
|
||||
interface QuantizedVectorsReader extends Closeable, Accountable {
|
||||
|
||||
QuantizedByteVectorValues getQuantizedVectorValues(String fieldName) throws IOException;
|
||||
|
||||
ScalarQuantizer getQuantizationState(String fieldName);
|
||||
}
|
|
@ -0,0 +1,31 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
|
||||
|
||||
/**
|
||||
* Random access values for <code>byte[]</code>, but also includes accessing the score correction
|
||||
* constant for the current vector in the buffer.
|
||||
*/
|
||||
interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues<byte[]> {
|
||||
float getScoreCorrectionConstant();
|
||||
|
||||
@Override
|
||||
RandomAccessQuantizedByteVectorValues copy() throws IOException;
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.ArrayUtil;
|
||||
import org.apache.lucene.util.ScalarQuantizedVectorSimilarity;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
|
||||
/** Quantized vector scorer */
|
||||
final class ScalarQuantizedRandomVectorScorer implements RandomVectorScorer {
|
||||
|
||||
private static float quantizeQuery(
|
||||
float[] query,
|
||||
byte[] quantizedQuery,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
ScalarQuantizer scalarQuantizer) {
|
||||
float[] processedQuery =
|
||||
switch (similarityFunction) {
|
||||
case EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT -> query;
|
||||
case COSINE -> {
|
||||
float[] queryCopy = ArrayUtil.copyOfSubArray(query, 0, query.length);
|
||||
VectorUtil.l2normalize(queryCopy);
|
||||
yield queryCopy;
|
||||
}
|
||||
};
|
||||
return scalarQuantizer.quantize(processedQuery, quantizedQuery, similarityFunction);
|
||||
}
|
||||
|
||||
private final byte[] quantizedQuery;
|
||||
private final float queryOffset;
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final ScalarQuantizedVectorSimilarity similarity;
|
||||
|
||||
ScalarQuantizedRandomVectorScorer(
|
||||
ScalarQuantizedVectorSimilarity similarityFunction,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
byte[] query,
|
||||
float queryOffset) {
|
||||
this.quantizedQuery = query;
|
||||
this.queryOffset = queryOffset;
|
||||
this.similarity = similarityFunction;
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
ScalarQuantizedRandomVectorScorer(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
RandomAccessQuantizedByteVectorValues values,
|
||||
float[] query) {
|
||||
byte[] quantizedQuery = new byte[query.length];
|
||||
float correction = quantizeQuery(query, quantizedQuery, similarityFunction, scalarQuantizer);
|
||||
this.quantizedQuery = quantizedQuery;
|
||||
this.queryOffset = correction;
|
||||
this.similarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
similarityFunction, scalarQuantizer.getConstantMultiplier());
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(int node) throws IOException {
|
||||
byte[] storedVectorValue = values.vectorValue(node);
|
||||
float storedVectorCorrection = values.getScoreCorrectionConstant();
|
||||
return similarity.score(
|
||||
quantizedQuery, this.queryOffset, storedVectorValue, storedVectorCorrection);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.ScalarQuantizedVectorSimilarity;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorer;
|
||||
import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier;
|
||||
|
||||
/** Quantized vector scorer supplier */
|
||||
final class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier {
|
||||
|
||||
private final RandomAccessQuantizedByteVectorValues values;
|
||||
private final ScalarQuantizedVectorSimilarity similarity;
|
||||
|
||||
ScalarQuantizedRandomVectorScorerSupplier(
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
RandomAccessQuantizedByteVectorValues values) {
|
||||
this.similarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
similarityFunction, scalarQuantizer.getConstantMultiplier());
|
||||
this.values = values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RandomVectorScorer scorer(int ord) throws IOException {
|
||||
final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy();
|
||||
final byte[] queryVector = values.vectorValue(ord);
|
||||
final float queryOffset = values.getScoreCorrectionConstant();
|
||||
return new ScalarQuantizedRandomVectorScorer(similarity, vectorsCopy, queryVector, queryOffset);
|
||||
}
|
||||
}
|
|
@ -180,7 +180,7 @@
|
|||
* of files, recording dimensionally indexed fields, to enable fast numeric range filtering
|
||||
* and large numeric values like BigInteger and BigDecimal (1D) and geographic shape
|
||||
* intersection (2D, 3D).
|
||||
* <li>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}. The
|
||||
* <li>{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}. The
|
||||
* vector format stores numeric vectors in a format optimized for random access and
|
||||
* computation, supporting high-dimensional nearest-neighbor search.
|
||||
* </ul>
|
||||
|
@ -310,10 +310,11 @@
|
|||
* <td>Holds indexed points</td>
|
||||
* </tr>
|
||||
* <tr>
|
||||
* <td>{@link org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat Vector values}</td>
|
||||
* <td>.vec, .vem</td>
|
||||
* <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data, and
|
||||
* <code>.vem</code> the vector metadata</td>
|
||||
* <td>{@link org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat Vector values}</td>
|
||||
* <td>.vec, .vem, .veq, vex</td>
|
||||
* <td>Holds indexed vectors; <code>.vec</code> files contain the raw vector data,
|
||||
* <code>.vem</code> the vector metadata, <code>.veq</code> the quantized vector data, and <code>.vex</code> the
|
||||
* hnsw graph data.</td>
|
||||
* </tr>
|
||||
* </table>
|
||||
*
|
||||
|
@ -408,6 +409,8 @@
|
|||
* <li>In version 9.5, HNSW graph connections were changed to be delta-encoded with vints.
|
||||
* Additionally, metadata file size improvements were made by delta-encoding nodes by graph
|
||||
* layer and not writing the node ids for the zeroth layer.
|
||||
* <li>In version 9.9, Vector scalar quantization support was added. Allowing the HNSW vector
|
||||
* format to utilize int8 quantized vectors for float32 vector search.
|
||||
* </ul>
|
||||
*
|
||||
* <a id="Limitations"></a>
|
||||
|
|
|
@ -0,0 +1,98 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import static org.apache.lucene.util.VectorUtil.scaleMaxInnerProductScore;
|
||||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
||||
/**
|
||||
* Calculates and adjust the scores correctly for quantized vectors given the scalar quantization
|
||||
* parameters
|
||||
*/
|
||||
public interface ScalarQuantizedVectorSimilarity {
|
||||
|
||||
/**
|
||||
* Creates a {@link ScalarQuantizedVectorSimilarity} from a {@link VectorSimilarityFunction} and
|
||||
* the constant multiplier used for quantization.
|
||||
*
|
||||
* @param sim similarity function
|
||||
* @param constMultiplier constant multiplier used for quantization
|
||||
* @return a {@link ScalarQuantizedVectorSimilarity} that applies the appropriate corrections
|
||||
*/
|
||||
static ScalarQuantizedVectorSimilarity fromVectorSimilarity(
|
||||
VectorSimilarityFunction sim, float constMultiplier) {
|
||||
return switch (sim) {
|
||||
case EUCLIDEAN -> new Euclidean(constMultiplier);
|
||||
case COSINE, DOT_PRODUCT -> new DotProduct(constMultiplier);
|
||||
case MAXIMUM_INNER_PRODUCT -> new MaximumInnerProduct(constMultiplier);
|
||||
};
|
||||
}
|
||||
|
||||
float score(byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset);
|
||||
|
||||
/** Calculates euclidean distance on quantized vectors, applying the appropriate corrections */
|
||||
class Euclidean implements ScalarQuantizedVectorSimilarity {
|
||||
private final float constMultiplier;
|
||||
|
||||
public Euclidean(float constMultiplier) {
|
||||
this.constMultiplier = constMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(
|
||||
byte[] queryVector, float queryVectorOffset, byte[] storedVector, float vectorOffset) {
|
||||
int squareDistance = VectorUtil.squareDistance(storedVector, queryVector);
|
||||
float adjustedDistance = squareDistance * constMultiplier;
|
||||
return 1 / (1f + adjustedDistance);
|
||||
}
|
||||
}
|
||||
|
||||
/** Calculates dot product on quantized vectors, applying the appropriate corrections */
|
||||
class DotProduct implements ScalarQuantizedVectorSimilarity {
|
||||
private final float constMultiplier;
|
||||
|
||||
public DotProduct(float constMultiplier) {
|
||||
this.constMultiplier = constMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(
|
||||
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
|
||||
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
|
||||
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
|
||||
return (1 + adjustedDistance) / 2;
|
||||
}
|
||||
}
|
||||
|
||||
/** Calculates max inner product on quantized vectors, applying the appropriate corrections */
|
||||
class MaximumInnerProduct implements ScalarQuantizedVectorSimilarity {
|
||||
private final float constMultiplier;
|
||||
|
||||
public MaximumInnerProduct(float constMultiplier) {
|
||||
this.constMultiplier = constMultiplier;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score(
|
||||
byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
|
||||
int dotProduct = VectorUtil.dotProduct(storedVector, queryVector);
|
||||
float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
|
||||
return scaleMaxInnerProductScore(adjustedDistance);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,317 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Random;
|
||||
import java.util.stream.IntStream;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
|
||||
/**
|
||||
* Will scalar quantize float vectors into `int8` byte values. This is a lossy transformation.
|
||||
* Scalar quantization works by first calculating the quantiles of the float vector values. The
|
||||
* quantiles are calculated using the configured quantile/confidence interval. The [minQuantile,
|
||||
* maxQuantile] are then used to scale the values into the range [0, 127] and bucketed into the
|
||||
* nearest byte values.
|
||||
*
|
||||
* <h2>How Scalar Quantization Works</h2>
|
||||
*
|
||||
* <p>The basic mathematical equations behind this are fairly straight forward. Given a float vector
|
||||
* `v` and a quantile `q` we can calculate the quantiles of the vector values [minQuantile,
|
||||
* maxQuantile].
|
||||
*
|
||||
* <pre class="prettyprint">
|
||||
* byte = (float - minQuantile) * 127/(maxQuantile - minQuantile)
|
||||
* float = (maxQuantile - minQuantile)/127 * byte + minQuantile
|
||||
* </pre>
|
||||
*
|
||||
* <p>This then means to multiply two float values together (e.g. dot_product) we can do the
|
||||
* following:
|
||||
*
|
||||
* <pre class="prettyprint">
|
||||
* float1 * float2 ~= (byte1 * (maxQuantile - minQuantile)/127 + minQuantile) * (byte2 * (maxQuantile - minQuantile)/127 + minQuantile)
|
||||
* float1 * float2 ~= (byte1 * byte2 * (maxQuantile - minQuantile)^2)/(127^2) + (byte1 * minQuantile * (maxQuantile - minQuantile)/127) + (byte2 * minQuantile * (maxQuantile - minQuantile)/127) + minQuantile^2
|
||||
* let alpha = (maxQuantile - minQuantile)/127
|
||||
* float1 * float2 ~= (byte1 * byte2 * alpha^2) + (byte1 * minQuantile * alpha) + (byte2 * minQuantile * alpha) + minQuantile^2
|
||||
* </pre>
|
||||
*
|
||||
* <p>The expansion for square distance is much simpler:
|
||||
*
|
||||
* <pre class="prettyprint">
|
||||
* square_distance = (float1 - float2)^2
|
||||
* (float1 - float2)^2 ~= (byte1 * alpha + minQuantile - byte2 * alpha - minQuantile)^2
|
||||
* = (alpha*byte1 + minQuantile)^2 + (alpha*byte2 + minQuantile)^2 - 2*(alpha*byte1 + minQuantile)(alpha*byte2 + minQuantile)
|
||||
* this can be simplified to:
|
||||
* = alpha^2 (byte1 - byte2)^2
|
||||
* </pre>
|
||||
*/
|
||||
public class ScalarQuantizer {
|
||||
|
||||
public static final int SCALAR_QUANTIZATION_SAMPLE_SIZE = 25_000;
|
||||
|
||||
private final float alpha;
|
||||
private final float scale;
|
||||
private final float minQuantile, maxQuantile, configuredQuantile;
|
||||
|
||||
/**
|
||||
* @param minQuantile the lower quantile of the distribution
|
||||
* @param maxQuantile the upper quantile of the distribution
|
||||
* @param configuredQuantile The configured quantile/confidence interval used to calculate the
|
||||
* quantiles.
|
||||
*/
|
||||
public ScalarQuantizer(float minQuantile, float maxQuantile, float configuredQuantile) {
|
||||
assert maxQuantile >= minQuantile;
|
||||
this.minQuantile = minQuantile;
|
||||
this.maxQuantile = maxQuantile;
|
||||
this.scale = 127f / (maxQuantile - minQuantile);
|
||||
this.alpha = (maxQuantile - minQuantile) / 127f;
|
||||
this.configuredQuantile = configuredQuantile;
|
||||
}
|
||||
|
||||
/**
|
||||
* Quantize a float vector into a byte vector
|
||||
*
|
||||
* @param src the source vector
|
||||
* @param dest the destination vector
|
||||
* @param similarityFunction the similarity function used to calculate the quantile
|
||||
* @return the corrective offset that needs to be applied to the score
|
||||
*/
|
||||
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
|
||||
assert src.length == dest.length;
|
||||
float correctiveOffset = 0f;
|
||||
for (int i = 0; i < src.length; i++) {
|
||||
float v = src[i];
|
||||
// Make sure the value is within the quantile range, cutting off the tails
|
||||
// see first parenthesis in equation: byte = (float - minQuantile) * 127/(maxQuantile -
|
||||
// minQuantile)
|
||||
float dx = Math.max(minQuantile, Math.min(maxQuantile, src[i])) - minQuantile;
|
||||
// Scale the value to the range [0, 127], this is our quantized value
|
||||
// scale = 127/(maxQuantile - minQuantile)
|
||||
float dxs = scale * dx;
|
||||
// We multiply by `alpha` here to get the quantized value back into the original range
|
||||
// to aid in calculating the corrective offset
|
||||
float dxq = Math.round(dxs) * alpha;
|
||||
// Calculate the corrective offset that needs to be applied to the score
|
||||
// in addition to the `byte * minQuantile * alpha` term in the equation
|
||||
// we add the `(dx - dxq) * dxq` term to account for the fact that the quantized value
|
||||
// will be rounded to the nearest whole number and lose some accuracy
|
||||
// Additionally, we account for the global correction of `minQuantile^2` in the equation
|
||||
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
|
||||
dest[i] = (byte) Math.round(dxs);
|
||||
}
|
||||
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
|
||||
return 0;
|
||||
}
|
||||
return correctiveOffset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Recalculate the old score corrective value given new current quantiles
|
||||
*
|
||||
* @param quantizedVector the old vector
|
||||
* @param oldQuantizer the old quantizer
|
||||
* @param similarityFunction the similarity function used to calculate the quantile
|
||||
* @return the new offset
|
||||
*/
|
||||
public float recalculateCorrectiveOffset(
|
||||
byte[] quantizedVector,
|
||||
ScalarQuantizer oldQuantizer,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
if (similarityFunction.equals(VectorSimilarityFunction.EUCLIDEAN)) {
|
||||
return 0f;
|
||||
}
|
||||
float correctiveOffset = 0f;
|
||||
for (int i = 0; i < quantizedVector.length; i++) {
|
||||
// dequantize the old value in order to recalculate the corrective offset
|
||||
float v = (oldQuantizer.alpha * quantizedVector[i]) + oldQuantizer.minQuantile;
|
||||
float dx = Math.max(minQuantile, Math.min(maxQuantile, v)) - minQuantile;
|
||||
float dxs = scale * dx;
|
||||
float dxq = Math.round(dxs) * alpha;
|
||||
correctiveOffset += minQuantile * (v - minQuantile / 2.0F) + (dx - dxq) * dxq;
|
||||
}
|
||||
return correctiveOffset;
|
||||
}
|
||||
|
||||
/**
|
||||
* Dequantize a byte vector into a float vector
|
||||
*
|
||||
* @param src the source vector
|
||||
* @param dest the destination vector
|
||||
*/
|
||||
public void deQuantize(byte[] src, float[] dest) {
|
||||
assert src.length == dest.length;
|
||||
for (int i = 0; i < src.length; i++) {
|
||||
dest[i] = (alpha * src[i]) + minQuantile;
|
||||
}
|
||||
}
|
||||
|
||||
public float getLowerQuantile() {
|
||||
return minQuantile;
|
||||
}
|
||||
|
||||
public float getUpperQuantile() {
|
||||
return maxQuantile;
|
||||
}
|
||||
|
||||
public float getConfiguredQuantile() {
|
||||
return configuredQuantile;
|
||||
}
|
||||
|
||||
public float getConstantMultiplier() {
|
||||
return alpha * alpha;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "ScalarQuantizer{"
|
||||
+ "minQuantile="
|
||||
+ minQuantile
|
||||
+ ", maxQuantile="
|
||||
+ maxQuantile
|
||||
+ ", configuredQuantile="
|
||||
+ configuredQuantile
|
||||
+ '}';
|
||||
}
|
||||
|
||||
private static final Random random = new Random(42);
|
||||
|
||||
/**
|
||||
* This will read the float vector values and calculate the quantiles. If the number of float
|
||||
* vectors is less than {@link #SCALAR_QUANTIZATION_SAMPLE_SIZE} then all the values will be read
|
||||
* and the quantiles calculated. If the number of float vectors is greater than {@link
|
||||
* #SCALAR_QUANTIZATION_SAMPLE_SIZE} then a random sample of {@link
|
||||
* #SCALAR_QUANTIZATION_SAMPLE_SIZE} will be read and the quantiles calculated.
|
||||
*
|
||||
* @param floatVectorValues the float vector values from which to calculate the quantiles
|
||||
* @param quantile the quantile/confidence interval used to calculate the quantiles
|
||||
* @return A new {@link ScalarQuantizer} instance
|
||||
* @throws IOException if there is an error reading the float vector values
|
||||
*/
|
||||
public static ScalarQuantizer fromVectors(FloatVectorValues floatVectorValues, float quantile)
|
||||
throws IOException {
|
||||
assert 0.9f <= quantile && quantile <= 1f;
|
||||
if (floatVectorValues.size() == 0) {
|
||||
return new ScalarQuantizer(0f, 0f, quantile);
|
||||
}
|
||||
if (quantile == 1f) {
|
||||
float min = Float.POSITIVE_INFINITY;
|
||||
float max = Float.NEGATIVE_INFINITY;
|
||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
for (float v : floatVectorValues.vectorValue()) {
|
||||
min = Math.min(min, v);
|
||||
max = Math.max(max, v);
|
||||
}
|
||||
}
|
||||
return new ScalarQuantizer(min, max, quantile);
|
||||
}
|
||||
int dim = floatVectorValues.dimension();
|
||||
if (floatVectorValues.size() < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
|
||||
int copyOffset = 0;
|
||||
float[] values = new float[floatVectorValues.size() * dim];
|
||||
while (floatVectorValues.nextDoc() != NO_MORE_DOCS) {
|
||||
float[] floatVector = floatVectorValues.vectorValue();
|
||||
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
||||
copyOffset += dim;
|
||||
}
|
||||
float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
|
||||
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
|
||||
}
|
||||
int numFloatVecs = floatVectorValues.size();
|
||||
// Reservoir sample the vector ordinals we want to read
|
||||
float[] values = new float[SCALAR_QUANTIZATION_SAMPLE_SIZE * dim];
|
||||
int[] vectorsToTake = IntStream.range(0, SCALAR_QUANTIZATION_SAMPLE_SIZE).toArray();
|
||||
for (int i = SCALAR_QUANTIZATION_SAMPLE_SIZE; i < numFloatVecs; i++) {
|
||||
int j = random.nextInt(i + 1);
|
||||
if (j < SCALAR_QUANTIZATION_SAMPLE_SIZE) {
|
||||
vectorsToTake[j] = i;
|
||||
}
|
||||
}
|
||||
Arrays.sort(vectorsToTake);
|
||||
int copyOffset = 0;
|
||||
int index = 0;
|
||||
for (int i : vectorsToTake) {
|
||||
while (index <= i) {
|
||||
// We cannot use `advance(docId)` as MergedVectorValues does not support it
|
||||
floatVectorValues.nextDoc();
|
||||
index++;
|
||||
}
|
||||
assert floatVectorValues.docID() != NO_MORE_DOCS;
|
||||
float[] floatVector = floatVectorValues.vectorValue();
|
||||
System.arraycopy(floatVector, 0, values, copyOffset, floatVector.length);
|
||||
copyOffset += dim;
|
||||
}
|
||||
float[] upperAndLower = getUpperAndLowerQuantile(values, quantile);
|
||||
return new ScalarQuantizer(upperAndLower[0], upperAndLower[1], quantile);
|
||||
}
|
||||
|
||||
/**
|
||||
* Takes an array of floats, sorted or not, and returns a minimum and maximum value. These values
|
||||
* are such that they reside on the `(1 - quantile)/2` and `quantile/2` percentiles. Example:
|
||||
* providing floats `[0..100]` and asking for `90` quantiles will return `5` and `95`.
|
||||
*
|
||||
* @param arr array of floats
|
||||
* @param quantileFloat the configured quantile
|
||||
* @return lower and upper quantile values
|
||||
*/
|
||||
static float[] getUpperAndLowerQuantile(float[] arr, float quantileFloat) {
|
||||
assert 0.9f <= quantileFloat && quantileFloat <= 1f;
|
||||
int selectorIndex = (int) (arr.length * (1f - quantileFloat) / 2f + 0.5f);
|
||||
if (selectorIndex > 0) {
|
||||
Selector selector = new FloatSelector(arr);
|
||||
selector.select(0, arr.length, arr.length - selectorIndex);
|
||||
selector.select(0, arr.length - selectorIndex, selectorIndex);
|
||||
}
|
||||
float min = Float.POSITIVE_INFINITY;
|
||||
float max = Float.NEGATIVE_INFINITY;
|
||||
for (int i = selectorIndex; i < arr.length - selectorIndex; i++) {
|
||||
min = Math.min(arr[i], min);
|
||||
max = Math.max(arr[i], max);
|
||||
}
|
||||
return new float[] {min, max};
|
||||
}
|
||||
|
||||
private static class FloatSelector extends IntroSelector {
|
||||
float pivot = Float.NaN;
|
||||
|
||||
private final float[] arr;
|
||||
|
||||
private FloatSelector(float[] arr) {
|
||||
this.arr = arr;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void setPivot(int i) {
|
||||
pivot = arr[i];
|
||||
}
|
||||
|
||||
@Override
|
||||
protected int comparePivot(int j) {
|
||||
return Float.compare(pivot, arr[j]);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void swap(int i, int j) {
|
||||
final float tmp = arr[i];
|
||||
arr[i] = arr[j];
|
||||
arr[j] = tmp;
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.lucene.util.hnsw;
|
||||
|
||||
import java.io.Closeable;
|
||||
|
||||
/**
|
||||
* A supplier that creates {@link RandomVectorScorer} from an ordinal. Caller should be sure to
|
||||
* close after use
|
||||
*/
|
||||
public interface CloseableRandomVectorScorerSupplier
|
||||
extends Closeable, RandomVectorScorerSupplier {}
|
|
@ -13,4 +13,4 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat
|
||||
org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.NoMergePolicy;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
|
||||
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||
@Override
|
||||
protected Codec getCodec() {
|
||||
return new Lucene99Codec() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene99HnswVectorsFormat(
|
||||
Lucene99HnswVectorsFormat.DEFAULT_MAX_CONN,
|
||||
Lucene99HnswVectorsFormat.DEFAULT_BEAM_WIDTH,
|
||||
new Lucene99ScalarQuantizedVectorsFormat());
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
public void testQuantizedVectorsWriteAndRead() throws Exception {
|
||||
// create lucene directory with codec
|
||||
int numVectors = 1 + random().nextInt(50);
|
||||
VectorSimilarityFunction similarityFunction = randomSimilarity();
|
||||
int dim = random().nextInt(64) + 1;
|
||||
List<float[]> vectors = new ArrayList<>(numVectors);
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
vectors.add(randomVector(dim));
|
||||
}
|
||||
float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectors(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile);
|
||||
float[] expectedCorrections = new float[numVectors];
|
||||
byte[][] expectedVectors = new byte[numVectors][];
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
expectedVectors[i] = new byte[dim];
|
||||
expectedCorrections[i] =
|
||||
scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction);
|
||||
}
|
||||
float[] randomlyReusedVector = new float[dim];
|
||||
|
||||
try (Directory dir = newDirectory();
|
||||
IndexWriter w =
|
||||
new IndexWriter(
|
||||
dir,
|
||||
new IndexWriterConfig()
|
||||
.setMaxBufferedDocs(numVectors + 1)
|
||||
.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH)
|
||||
.setMergePolicy(NoMergePolicy.INSTANCE))) {
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
Document doc = new Document();
|
||||
// randomly reuse a vector, this ensures the underlying codec doesn't rely on the array
|
||||
// reference
|
||||
final float[] v;
|
||||
if (random().nextBoolean()) {
|
||||
System.arraycopy(vectors.get(i), 0, randomlyReusedVector, 0, dim);
|
||||
v = randomlyReusedVector;
|
||||
} else {
|
||||
v = vectors.get(i);
|
||||
}
|
||||
doc.add(new KnnFloatVectorField("f", v, similarityFunction));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
w.commit();
|
||||
try (IndexReader reader = DirectoryReader.open(w)) {
|
||||
LeafReader r = getOnlyLeafReader(reader);
|
||||
if (r instanceof CodecReader codecReader) {
|
||||
KnnVectorsReader knnVectorsReader = codecReader.getVectorReader();
|
||||
if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) {
|
||||
knnVectorsReader = fieldsReader.getFieldReader("f");
|
||||
}
|
||||
if (knnVectorsReader instanceof Lucene99HnswVectorsReader hnswReader) {
|
||||
assertNotNull(hnswReader.getQuantizationState("f"));
|
||||
QuantizedByteVectorValues quantizedByteVectorValues =
|
||||
hnswReader.getQuantizedVectorValues("f");
|
||||
int docId = -1;
|
||||
while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) {
|
||||
byte[] vector = quantizedByteVectorValues.vectorValue();
|
||||
float offset = quantizedByteVectorValues.getScoreCorrectionConstant();
|
||||
for (int i = 0; i < dim; i++) {
|
||||
assertEquals(vector[i], expectedVectors[docId][i]);
|
||||
}
|
||||
assertEquals(offset, expectedCorrections[docId], 0.00001f);
|
||||
}
|
||||
} else {
|
||||
fail("reader is not Lucene99HnswVectorsReader");
|
||||
}
|
||||
} else {
|
||||
fail("reader is not CodecReader");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testToString() {
|
||||
FilterCodec customCodec =
|
||||
new FilterCodec("foo", Codec.getDefault()) {
|
||||
@Override
|
||||
public KnnVectorsFormat knnVectorsFormat() {
|
||||
return new Lucene99HnswVectorsFormat(
|
||||
10, 20, new Lucene99ScalarQuantizedVectorsFormat(0.9f));
|
||||
}
|
||||
};
|
||||
String expectedString =
|
||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=Lucene99ScalarQuantizedVectorsFormat(name=Lucene99ScalarQuantizedVectorsFormat, quantile=0.9))";
|
||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||
}
|
||||
}
|
|
@ -14,7 +14,7 @@
|
|||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene95;
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
|
@ -22,7 +22,7 @@ import org.apache.lucene.codecs.KnnVectorsFormat;
|
|||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
||||
public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||
public class TestLucene99HnswVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||
@Override
|
||||
protected Codec getCodec() {
|
||||
return TestUtil.getDefaultCodec();
|
||||
|
@ -33,20 +33,22 @@ public class TestLucene95HnswVectorsFormat extends BaseKnnVectorsFormatTestCase
|
|||
new FilterCodec("foo", Codec.getDefault()) {
|
||||
@Override
|
||||
public KnnVectorsFormat knnVectorsFormat() {
|
||||
return new Lucene95HnswVectorsFormat(10, 20);
|
||||
return new Lucene99HnswVectorsFormat(10, 20, null);
|
||||
}
|
||||
};
|
||||
String expectedString =
|
||||
"Lucene95HnswVectorsFormat(name=Lucene95HnswVectorsFormat, maxConn=10, beamWidth=20)";
|
||||
"Lucene99HnswVectorsFormat(name=Lucene99HnswVectorsFormat, maxConn=10, beamWidth=20, quantizer=none)";
|
||||
assertEquals(expectedString, customCodec.knnVectorsFormat().toString());
|
||||
}
|
||||
|
||||
public void testLimits() {
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(-1, 20));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(0, 20));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, 0));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, -1));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(512 + 1, 20));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene95HnswVectorsFormat(20, 3201));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(-1, 20, null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(0, 20, null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 0, null));
|
||||
expectThrows(IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, -1, null));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(512 + 1, 20, null));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class, () -> new Lucene99HnswVectorsFormat(20, 3201, null));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.codecs.lucene99;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.ScalarQuantizer;
|
||||
|
||||
public class TestLucene99ScalarQuantizedVectorsFormat extends LuceneTestCase {
|
||||
|
||||
public void testDefaultQuantile() {
|
||||
float defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(99);
|
||||
assertEquals(0.99f, defaultQuantile, 1e-5);
|
||||
defaultQuantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(1);
|
||||
assertEquals(0.9f, defaultQuantile, 1e-5);
|
||||
defaultQuantile =
|
||||
Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(Integer.MAX_VALUE - 2);
|
||||
assertEquals(1.0f, defaultQuantile, 1e-5);
|
||||
}
|
||||
|
||||
public void testLimits() {
|
||||
expectThrows(
|
||||
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(0.89f));
|
||||
expectThrows(
|
||||
IllegalArgumentException.class, () -> new Lucene99ScalarQuantizedVectorsFormat(1.1f));
|
||||
}
|
||||
|
||||
public void testQuantileMergeWithMissing() {
|
||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
||||
quantiles.add(null);
|
||||
List<Integer> segmentSizes = List.of(1, 1, 1, 1);
|
||||
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f));
|
||||
assertNull(Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(List.of(), List.of(), 0.1f));
|
||||
}
|
||||
|
||||
public void testQuantileMerge() {
|
||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
||||
List<Integer> segmentSizes = List.of(1, 1, 1);
|
||||
ScalarQuantizer merged =
|
||||
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
|
||||
assertEquals(0.2f, merged.getLowerQuantile(), 1e-5);
|
||||
assertEquals(0.3f, merged.getUpperQuantile(), 1e-5);
|
||||
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
|
||||
}
|
||||
|
||||
public void testQuantileMergeWithDifferentSegmentSizes() {
|
||||
List<ScalarQuantizer> quantiles = new ArrayList<>();
|
||||
quantiles.add(new ScalarQuantizer(0.1f, 0.2f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.2f, 0.3f, 0.1f));
|
||||
quantiles.add(new ScalarQuantizer(0.3f, 0.4f, 0.1f));
|
||||
List<Integer> segmentSizes = List.of(1, 2, 3);
|
||||
ScalarQuantizer merged =
|
||||
Lucene99ScalarQuantizedVectorsWriter.mergeQuantiles(quantiles, segmentSizes, 0.1f);
|
||||
assertEquals(0.2333333f, merged.getLowerQuantile(), 1e-5);
|
||||
assertEquals(0.3333333f, merged.getUpperQuantile(), 1e-5);
|
||||
assertEquals(0.1f, merged.getConfiguredQuantile(), 1e-5);
|
||||
}
|
||||
}
|
|
@ -29,7 +29,7 @@ import org.apache.lucene.codecs.KnnFieldVectorsWriter;
|
|||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.KnnVectorsReader;
|
||||
import org.apache.lucene.codecs.KnnVectorsWriter;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
|
@ -170,8 +170,8 @@ public class TestPerFieldKnnVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
|||
try (Directory directory = newDirectory()) {
|
||||
IndexWriterConfig iwc = newIndexWriterConfig(new MockAnalyzer(random()));
|
||||
KnnVectorsFormat format1 =
|
||||
new KnnVectorsFormatMaxDims32(new Lucene95HnswVectorsFormat(16, 100));
|
||||
KnnVectorsFormat format2 = new Lucene95HnswVectorsFormat(16, 100);
|
||||
new KnnVectorsFormatMaxDims32(new Lucene99HnswVectorsFormat(16, 100, null));
|
||||
KnnVectorsFormat format2 = new Lucene99HnswVectorsFormat(16, 100, null);
|
||||
iwc.setCodec(
|
||||
new AssertingCodec() {
|
||||
@Override
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
*/
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.randomIntBetween;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.util.hnsw.HnswGraphBuilder.randSeed;
|
||||
|
@ -31,8 +32,9 @@ import java.util.concurrent.CountDownLatch;
|
|||
import org.apache.lucene.codecs.Codec;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
|
@ -78,23 +80,15 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
M = random().nextInt(256) + 3;
|
||||
}
|
||||
|
||||
codec =
|
||||
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
|
||||
@Override
|
||||
public KnnVectorsFormat knnVectorsFormat() {
|
||||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
int similarity = random().nextInt(VectorSimilarityFunction.values().length - 1) + 1;
|
||||
similarityFunction = VectorSimilarityFunction.values()[similarity];
|
||||
vectorEncoding = randomVectorEncoding();
|
||||
|
||||
Lucene99ScalarQuantizedVectorsFormat scalarQuantizedVectorsFormat =
|
||||
vectorEncoding.equals(VectorEncoding.FLOAT32) && randomBoolean()
|
||||
? new Lucene99ScalarQuantizedVectorsFormat(1f)
|
||||
: null;
|
||||
|
||||
codec =
|
||||
new FilterCodec(TestUtil.getDefaultCodec().getName(), TestUtil.getDefaultCodec()) {
|
||||
@Override
|
||||
|
@ -102,13 +96,14 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene99HnswVectorsFormat(
|
||||
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, scalarQuantizedVectorsFormat);
|
||||
}
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
if (vectorEncoding == VectorEncoding.FLOAT32) {
|
||||
if (vectorEncoding == VectorEncoding.FLOAT32 && scalarQuantizedVectorsFormat == null) {
|
||||
float32Codec = codec;
|
||||
} else {
|
||||
float32Codec =
|
||||
|
@ -118,7 +113,8 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH);
|
||||
return new Lucene99HnswVectorsFormat(
|
||||
M, HnswGraphBuilder.DEFAULT_BEAM_WIDTH, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -409,8 +405,8 @@ public class TestKnnGraph extends LuceneTestCase {
|
|||
if (perFieldReader == null) {
|
||||
continue;
|
||||
}
|
||||
Lucene95HnswVectorsReader vectorReader =
|
||||
(Lucene95HnswVectorsReader) perFieldReader.getFieldReader(vectorField);
|
||||
Lucene99HnswVectorsReader vectorReader =
|
||||
(Lucene99HnswVectorsReader) perFieldReader.getFieldReader(vectorField);
|
||||
if (vectorReader == null) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -0,0 +1,216 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import static org.apache.lucene.util.TestScalarQuantizer.fromFloats;
|
||||
import static org.apache.lucene.util.TestScalarQuantizer.randomFloatArray;
|
||||
import static org.apache.lucene.util.TestScalarQuantizer.randomFloats;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
|
||||
|
||||
public void testToEuclidean() throws IOException {
|
||||
int dims = 128;
|
||||
int numVecs = 100;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
||||
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN);
|
||||
float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
|
||||
ScalarQuantizedVectorSimilarity quantizedSimilarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
VectorSimilarityFunction.EUCLIDEAN, scalarQuantizer.getConstantMultiplier());
|
||||
assertQuantizedScores(
|
||||
floats,
|
||||
quantized,
|
||||
offsets,
|
||||
query,
|
||||
error,
|
||||
VectorSimilarityFunction.EUCLIDEAN,
|
||||
quantizedSimilarity,
|
||||
scalarQuantizer);
|
||||
}
|
||||
}
|
||||
|
||||
public void testToCosine() throws IOException {
|
||||
int dims = 128;
|
||||
int numVecs = 100;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
|
||||
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
||||
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloatsNormalized(floats);
|
||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectorsNormalized(
|
||||
scalarQuantizer, floats, quantized, VectorSimilarityFunction.COSINE);
|
||||
float[] query = ArrayUtil.copyOfSubArray(floats[0], 0, dims);
|
||||
VectorUtil.l2normalize(query);
|
||||
ScalarQuantizedVectorSimilarity quantizedSimilarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
VectorSimilarityFunction.COSINE, scalarQuantizer.getConstantMultiplier());
|
||||
assertQuantizedScores(
|
||||
floats,
|
||||
quantized,
|
||||
offsets,
|
||||
query,
|
||||
error,
|
||||
VectorSimilarityFunction.COSINE,
|
||||
quantizedSimilarity,
|
||||
scalarQuantizer);
|
||||
}
|
||||
}
|
||||
|
||||
public void testToDotProduct() throws IOException {
|
||||
int dims = 128;
|
||||
int numVecs = 100;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
for (float[] fs : floats) {
|
||||
VectorUtil.l2normalize(fs);
|
||||
}
|
||||
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
||||
float error = Math.max((100 - quantile) * 0.01f, 0.01f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT);
|
||||
float[] query = randomFloatArray(dims);
|
||||
VectorUtil.l2normalize(query);
|
||||
ScalarQuantizedVectorSimilarity quantizedSimilarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
VectorSimilarityFunction.DOT_PRODUCT, scalarQuantizer.getConstantMultiplier());
|
||||
assertQuantizedScores(
|
||||
floats,
|
||||
quantized,
|
||||
offsets,
|
||||
query,
|
||||
error,
|
||||
VectorSimilarityFunction.DOT_PRODUCT,
|
||||
quantizedSimilarity,
|
||||
scalarQuantizer);
|
||||
}
|
||||
}
|
||||
|
||||
public void testToMaxInnerProduct() throws IOException {
|
||||
int dims = 128;
|
||||
int numVecs = 100;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
for (float quantile : new float[] {0.9f, 0.95f, 0.99f, (1 - 1f / (dims + 1)), 1f}) {
|
||||
float error = Math.max((100 - quantile) * 0.5f, 0.5f);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, quantile);
|
||||
byte[][] quantized = new byte[floats.length][];
|
||||
float[] offsets =
|
||||
quantizeVectors(
|
||||
scalarQuantizer, floats, quantized, VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT);
|
||||
float[] query = randomFloatArray(dims);
|
||||
ScalarQuantizedVectorSimilarity quantizedSimilarity =
|
||||
ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
|
||||
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
|
||||
scalarQuantizer.getConstantMultiplier());
|
||||
assertQuantizedScores(
|
||||
floats,
|
||||
quantized,
|
||||
offsets,
|
||||
query,
|
||||
error,
|
||||
VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT,
|
||||
quantizedSimilarity,
|
||||
scalarQuantizer);
|
||||
}
|
||||
}
|
||||
|
||||
private void assertQuantizedScores(
|
||||
float[][] floats,
|
||||
byte[][] quantized,
|
||||
float[] storedOffsets,
|
||||
float[] query,
|
||||
float error,
|
||||
VectorSimilarityFunction similarityFunction,
|
||||
ScalarQuantizedVectorSimilarity quantizedSimilarity,
|
||||
ScalarQuantizer scalarQuantizer) {
|
||||
for (int i = 0; i < floats.length; i++) {
|
||||
float storedOffset = storedOffsets[i];
|
||||
byte[] quantizedQuery = new byte[query.length];
|
||||
float queryOffset = scalarQuantizer.quantize(query, quantizedQuery, similarityFunction);
|
||||
float original = similarityFunction.compare(query, floats[i]);
|
||||
float quantizedScore =
|
||||
quantizedSimilarity.score(quantizedQuery, queryOffset, quantized[i], storedOffset);
|
||||
assertEquals("Not within acceptable error [" + error + "]", original, quantizedScore, error);
|
||||
}
|
||||
}
|
||||
|
||||
private static float[] quantizeVectors(
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
float[][] floats,
|
||||
byte[][] quantized,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
int i = 0;
|
||||
float[] offsets = new float[floats.length];
|
||||
for (float[] v : floats) {
|
||||
quantized[i] = new byte[v.length];
|
||||
offsets[i] = scalarQuantizer.quantize(v, quantized[i], similarityFunction);
|
||||
++i;
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private static float[] quantizeVectorsNormalized(
|
||||
ScalarQuantizer scalarQuantizer,
|
||||
float[][] floats,
|
||||
byte[][] quantized,
|
||||
VectorSimilarityFunction similarityFunction) {
|
||||
int i = 0;
|
||||
float[] offsets = new float[floats.length];
|
||||
for (float[] f : floats) {
|
||||
float[] v = ArrayUtil.copyOfSubArray(f, 0, f.length);
|
||||
VectorUtil.l2normalize(v);
|
||||
quantized[i] = new byte[v.length];
|
||||
offsets[i] = scalarQuantizer.quantize(v, quantized[i], similarityFunction);
|
||||
++i;
|
||||
}
|
||||
return offsets;
|
||||
}
|
||||
|
||||
private static FloatVectorValues fromFloatsNormalized(float[][] floats) {
|
||||
return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats) {
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= floats.length) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
float[] v = ArrayUtil.copyOfSubArray(floats[curDoc], 0, floats[curDoc].length);
|
||||
VectorUtil.l2normalize(v);
|
||||
return v;
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.util;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestScalarQuantizer extends LuceneTestCase {
|
||||
|
||||
public void testQuantizeAndDeQuantize() throws IOException {
|
||||
int dims = 128;
|
||||
int numVecs = 100;
|
||||
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1);
|
||||
float[] dequantized = new float[dims];
|
||||
byte[] quantized = new byte[dims];
|
||||
byte[] requantized = new byte[dims];
|
||||
for (int i = 0; i < numVecs; i++) {
|
||||
scalarQuantizer.quantize(floats[i], quantized, similarityFunction);
|
||||
scalarQuantizer.deQuantize(quantized, dequantized);
|
||||
scalarQuantizer.quantize(dequantized, requantized, similarityFunction);
|
||||
for (int j = 0; j < dims; j++) {
|
||||
assertEquals(dequantized[j], floats[i][j], 0.02);
|
||||
assertEquals(quantized[j], requantized[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testQuantiles() {
|
||||
float[] percs = new float[1000];
|
||||
for (int i = 0; i < 1000; i++) {
|
||||
percs[i] = (float) i;
|
||||
}
|
||||
shuffleArray(percs);
|
||||
float[] upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.9f);
|
||||
assertEquals(50f, upperAndLower[0], 1e-7);
|
||||
assertEquals(949f, upperAndLower[1], 1e-7);
|
||||
shuffleArray(percs);
|
||||
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.95f);
|
||||
assertEquals(25f, upperAndLower[0], 1e-7);
|
||||
assertEquals(974f, upperAndLower[1], 1e-7);
|
||||
shuffleArray(percs);
|
||||
upperAndLower = ScalarQuantizer.getUpperAndLowerQuantile(percs, 0.99f);
|
||||
assertEquals(5f, upperAndLower[0], 1e-7);
|
||||
assertEquals(994f, upperAndLower[1], 1e-7);
|
||||
}
|
||||
|
||||
public void testEdgeCase() {
|
||||
float[] upperAndLower =
|
||||
ScalarQuantizer.getUpperAndLowerQuantile(new float[] {1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, 0.9f);
|
||||
assertEquals(1f, upperAndLower[0], 1e-7f);
|
||||
assertEquals(1f, upperAndLower[1], 1e-7f);
|
||||
}
|
||||
|
||||
static void shuffleArray(float[] ar) {
|
||||
for (int i = ar.length - 1; i > 0; i--) {
|
||||
int index = random().nextInt(i + 1);
|
||||
float a = ar[index];
|
||||
ar[index] = ar[i];
|
||||
ar[i] = a;
|
||||
}
|
||||
}
|
||||
|
||||
static float[] randomFloatArray(int dims) {
|
||||
float[] arr = new float[dims];
|
||||
for (int j = 0; j < dims; j++) {
|
||||
arr[j] = random().nextFloat(-1, 1);
|
||||
}
|
||||
return arr;
|
||||
}
|
||||
|
||||
static float[][] randomFloats(int num, int dims) {
|
||||
float[][] floats = new float[num][];
|
||||
for (int i = 0; i < num; i++) {
|
||||
floats[i] = randomFloatArray(dims);
|
||||
}
|
||||
return floats;
|
||||
}
|
||||
|
||||
static FloatVectorValues fromFloats(float[][] floats) {
|
||||
return new TestSimpleFloatVectorValues(floats);
|
||||
}
|
||||
|
||||
static class TestSimpleFloatVectorValues extends FloatVectorValues {
|
||||
protected final float[][] floats;
|
||||
protected int curDoc = -1;
|
||||
|
||||
TestSimpleFloatVectorValues(float[][] values) {
|
||||
this.floats = values;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return floats[0].length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return floats.length;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
if (curDoc == -1 || curDoc >= floats.length) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
return floats[curDoc];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (curDoc >= floats.length) {
|
||||
return NO_MORE_DOCS;
|
||||
}
|
||||
return curDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
curDoc++;
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = target;
|
||||
return docID();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -40,8 +40,8 @@ import java.util.concurrent.TimeoutException;
|
|||
import java.util.stream.Collectors;
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsReader;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
|
@ -165,7 +165,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, beamWidth);
|
||||
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -237,7 +237,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, beamWidth);
|
||||
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -266,7 +266,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
assertEquals(indexedDoc, ctx.reader().numDocs());
|
||||
assertVectorsEqual(v3, values);
|
||||
HnswGraph graphValues =
|
||||
((Lucene95HnswVectorsReader)
|
||||
((Lucene99HnswVectorsReader)
|
||||
((PerFieldKnnVectorsFormat.FieldsReader)
|
||||
((CodecReader) ctx.reader()).getVectorReader())
|
||||
.getFieldReader("field"))
|
||||
|
@ -298,7 +298,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, beamWidth);
|
||||
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
@ -312,7 +312,7 @@ abstract class HnswGraphTestCase<T> extends LuceneTestCase {
|
|||
return new PerFieldKnnVectorsFormat() {
|
||||
@Override
|
||||
public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
|
||||
return new Lucene95HnswVectorsFormat(M, beamWidth);
|
||||
return new Lucene99HnswVectorsFormat(M, beamWidth, null);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ package org.apache.lucene.tests.codecs.vector;
|
|||
|
||||
import org.apache.lucene.codecs.FilterCodec;
|
||||
import org.apache.lucene.codecs.KnnVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
import org.apache.lucene.tests.util.TestUtil;
|
||||
|
||||
/**
|
||||
|
@ -32,12 +32,12 @@ public class ConfigurableMCodec extends FilterCodec {
|
|||
|
||||
public ConfigurableMCodec() {
|
||||
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
|
||||
knnVectorsFormat = new Lucene95HnswVectorsFormat(128, 100);
|
||||
knnVectorsFormat = new Lucene99HnswVectorsFormat(128, 100, null);
|
||||
}
|
||||
|
||||
public ConfigurableMCodec(int maxConn) {
|
||||
super("ConfigurableMCodec", TestUtil.getDefaultCodec());
|
||||
knnVectorsFormat = new Lucene95HnswVectorsFormat(maxConn, 100);
|
||||
knnVectorsFormat = new Lucene99HnswVectorsFormat(maxConn, 100, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -712,7 +712,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
}
|
||||
|
||||
private VectorSimilarityFunction randomSimilarity() {
|
||||
protected VectorSimilarityFunction randomSimilarity() {
|
||||
return VectorSimilarityFunction.values()[
|
||||
random().nextInt(VectorSimilarityFunction.values().length)];
|
||||
}
|
||||
|
@ -1221,7 +1221,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
iw.updateDocument(idTerm, doc);
|
||||
}
|
||||
|
||||
private float[] randomVector(int dim) {
|
||||
protected float[] randomVector(int dim) {
|
||||
assert dim > 0;
|
||||
float[] v = new float[dim];
|
||||
double squareSum = 0.0;
|
||||
|
@ -1409,6 +1409,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
}
|
||||
assertEquals(
|
||||
"encoding=" + vectorEncoding,
|
||||
fieldValuesCheckSum,
|
||||
checksum,
|
||||
vectorEncoding == VectorEncoding.BYTE ? numDocs * 0.2 : 1e-5);
|
||||
|
|
|
@ -55,8 +55,8 @@ import org.apache.lucene.codecs.PostingsFormat;
|
|||
import org.apache.lucene.codecs.blocktreeords.BlockTreeOrdsPostingsFormat;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90DocValuesFormat;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
|
||||
import org.apache.lucene.codecs.lucene95.Lucene95HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99Codec;
|
||||
import org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldDocValuesFormat;
|
||||
import org.apache.lucene.codecs.perfield.PerFieldPostingsFormat;
|
||||
import org.apache.lucene.document.BinaryDocValuesField;
|
||||
|
@ -1322,7 +1322,7 @@ public final class TestUtil {
|
|||
* Lucene.
|
||||
*/
|
||||
public static KnnVectorsFormat getDefaultKnnVectorsFormat() {
|
||||
return new Lucene95HnswVectorsFormat();
|
||||
return new Lucene99HnswVectorsFormat();
|
||||
}
|
||||
|
||||
public static boolean anyFilesExceptWriteLock(Directory dir) throws IOException {
|
||||
|
|
Loading…
Reference in New Issue