mirror of https://github.com/apache/lucene.git
Fix ScalarQuantization when used with COSINE similarity (#13615)
When quantizing vectors in a COSINE vector space, we normalize them. However, there is a bug when building the quantizer quantiles and we didn't always use the normalized vectors. Consequently, we would end up with poorly configured quantiles and recall will drop significantly (especially in sensitive cases like int4). closes: #13614
This commit is contained in:
parent
26b46ced07
commit
43c80117dd
|
@ -350,6 +350,9 @@ Bug Fixes
|
|||
* GITHUB#13553: Correct RamUsageEstimate for scalar quantized knn vector formats so that raw vectors are correctly
|
||||
accounted for. (Ben Trent)
|
||||
|
||||
* GITHUB#13615: Correct scalar quantization when used in conjunction with COSINE similarity. Vectors are normalized
|
||||
before quantization to ensure the cosine similarity is correctly calculated. (Ben Trent)
|
||||
|
||||
Other
|
||||
--------------------
|
||||
(No changes)
|
||||
|
|
|
@ -677,6 +677,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
Float confidenceInterval,
|
||||
byte bits)
|
||||
throws IOException {
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
|
||||
vectorSimilarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
}
|
||||
if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) {
|
||||
return ScalarQuantizer.fromVectorsAutoInterval(
|
||||
floatVectorValues, vectorSimilarityFunction, numVectors, bits);
|
||||
|
@ -797,10 +801,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
if (floatVectors.size() == 0) {
|
||||
return new ScalarQuantizer(0, 0, bits);
|
||||
}
|
||||
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
|
||||
ScalarQuantizer quantizer =
|
||||
buildScalarQuantizer(
|
||||
floatVectorValues,
|
||||
new FloatVectorWrapper(floatVectors),
|
||||
floatVectors.size(),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
confidenceInterval,
|
||||
|
@ -851,14 +854,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
|
||||
static class FloatVectorWrapper extends FloatVectorValues {
|
||||
private final List<float[]> vectorList;
|
||||
private final float[] copy;
|
||||
private final boolean normalize;
|
||||
protected int curDoc = -1;
|
||||
|
||||
FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
|
||||
FloatVectorWrapper(List<float[]> vectorList) {
|
||||
this.vectorList = vectorList;
|
||||
this.copy = new float[vectorList.get(0).length];
|
||||
this.normalize = normalize;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -876,11 +875,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
if (curDoc == -1 || curDoc >= vectorList.size()) {
|
||||
throw new IOException("Current doc not set or too many iterations");
|
||||
}
|
||||
if (normalize) {
|
||||
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
|
||||
VectorUtil.l2normalize(copy);
|
||||
return copy;
|
||||
}
|
||||
return vectorList.get(curDoc);
|
||||
}
|
||||
|
||||
|
@ -949,13 +943,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
// quantization?
|
||||
|| scalarQuantizer.getBits() <= 4
|
||||
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
|
||||
FloatVectorValues toQuantize =
|
||||
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
|
||||
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
|
||||
toQuantize = new NormalizedFloatVectorValues(toQuantize);
|
||||
}
|
||||
sub =
|
||||
new QuantizedByteVectorValueSub(
|
||||
mergeState.docMaps[i],
|
||||
new QuantizedFloatVectorValues(
|
||||
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name),
|
||||
fieldInfo.getVectorSimilarityFunction(),
|
||||
scalarQuantizer));
|
||||
toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer));
|
||||
} else {
|
||||
sub =
|
||||
new QuantizedByteVectorValueSub(
|
||||
|
@ -1042,7 +1039,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
private final FloatVectorValues values;
|
||||
private final ScalarQuantizer quantizer;
|
||||
private final byte[] quantizedVector;
|
||||
private final float[] normalizedVector;
|
||||
private float offsetValue = 0f;
|
||||
|
||||
private final VectorSimilarityFunction vectorSimilarityFunction;
|
||||
|
@ -1055,11 +1051,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
this.quantizer = quantizer;
|
||||
this.quantizedVector = new byte[values.dimension()];
|
||||
this.vectorSimilarityFunction = vectorSimilarityFunction;
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
this.normalizedVector = new float[values.dimension()];
|
||||
} else {
|
||||
this.normalizedVector = null;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -1111,15 +1102,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
}
|
||||
|
||||
private void quantize() throws IOException {
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
offsetValue =
|
||||
quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
|
||||
} else {
|
||||
offsetValue =
|
||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||
}
|
||||
offsetValue =
|
||||
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1216,4 +1200,60 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
|
|||
throw new UnsupportedOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
static final class NormalizedFloatVectorValues extends FloatVectorValues {
|
||||
private final FloatVectorValues values;
|
||||
private final float[] normalizedVector;
|
||||
int curDoc = -1;
|
||||
|
||||
public NormalizedFloatVectorValues(FloatVectorValues values) {
|
||||
this.values = values;
|
||||
this.normalizedVector = new float[values.dimension()];
|
||||
}
|
||||
|
||||
@Override
|
||||
public int dimension() {
|
||||
return values.dimension();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return values.size();
|
||||
}
|
||||
|
||||
@Override
|
||||
public float[] vectorValue() throws IOException {
|
||||
return normalizedVector;
|
||||
}
|
||||
|
||||
@Override
|
||||
public VectorScorer scorer(float[] query) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return values.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() throws IOException {
|
||||
curDoc = values.nextDoc();
|
||||
if (curDoc != NO_MORE_DOCS) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
}
|
||||
return curDoc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) throws IOException {
|
||||
curDoc = values.advance(target);
|
||||
if (curDoc != NO_MORE_DOCS) {
|
||||
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
|
||||
VectorUtil.l2normalize(normalizedVector);
|
||||
}
|
||||
return curDoc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,6 +47,8 @@ import org.apache.lucene.internal.vectorization.VectorizationProvider;
|
|||
*/
|
||||
public final class VectorUtil {
|
||||
|
||||
private static final float EPSILON = 1e-4f;
|
||||
|
||||
private static final VectorUtilSupport IMPL =
|
||||
VectorizationProvider.getInstance().getVectorUtilSupport();
|
||||
|
||||
|
@ -121,6 +123,11 @@ public final class VectorUtil {
|
|||
return v;
|
||||
}
|
||||
|
||||
public static boolean isUnitVector(float[] v) {
|
||||
double l1norm = IMPL.dotProduct(v, v);
|
||||
return Math.abs(l1norm - 1.0d) <= EPSILON;
|
||||
}
|
||||
|
||||
/**
|
||||
* Modifies the argument to be unit length, dividing by its l2-norm.
|
||||
*
|
||||
|
@ -138,7 +145,7 @@ public final class VectorUtil {
|
|||
return v;
|
||||
}
|
||||
}
|
||||
if (Math.abs(l1norm - 1.0d) <= 1e-5) {
|
||||
if (Math.abs(l1norm - 1.0d) <= EPSILON) {
|
||||
return v;
|
||||
}
|
||||
int dim = v.length;
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.lucene.search.HitQueue;
|
|||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.util.IntroSelector;
|
||||
import org.apache.lucene.util.Selector;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
* Will scalar quantize float vectors into `int8` byte values. This is a lossy transformation.
|
||||
|
@ -113,6 +114,7 @@ public class ScalarQuantizer {
|
|||
*/
|
||||
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
|
||||
assert src.length == dest.length;
|
||||
assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src);
|
||||
float correction = 0;
|
||||
for (int i = 0; i < src.length; i++) {
|
||||
correction += quantizeFloat(src[i], dest, i);
|
||||
|
@ -332,6 +334,7 @@ public class ScalarQuantizer {
|
|||
int totalVectorCount,
|
||||
byte bits)
|
||||
throws IOException {
|
||||
assert function != VectorSimilarityFunction.COSINE;
|
||||
if (totalVectorCount == 0) {
|
||||
return new ScalarQuantizer(0f, 0f, bits);
|
||||
}
|
||||
|
|
|
@ -35,6 +35,7 @@ import org.apache.lucene.document.Document;
|
|||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.CodecReader;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.IndexWriterConfig;
|
||||
|
@ -127,7 +128,6 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
|||
// create lucene directory with codec
|
||||
int numVectors = 1 + random().nextInt(50);
|
||||
VectorSimilarityFunction similarityFunction = randomSimilarity();
|
||||
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
|
||||
int dim = random().nextInt(64) + 1;
|
||||
if (dim % 2 == 1) {
|
||||
dim++;
|
||||
|
@ -136,25 +136,16 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
|||
for (int i = 0; i < numVectors; i++) {
|
||||
vectors.add(randomVector(dim));
|
||||
}
|
||||
FloatVectorValues toQuantize =
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
confidenceInterval != null && confidenceInterval == 0f
|
||||
? ScalarQuantizer.fromVectorsAutoInterval(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
similarityFunction,
|
||||
numVectors,
|
||||
(byte) bits)
|
||||
: ScalarQuantizer.fromVectors(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
confidenceInterval == null
|
||||
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
|
||||
: confidenceInterval,
|
||||
numVectors,
|
||||
(byte) bits);
|
||||
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
|
||||
toQuantize, numVectors, similarityFunction, confidenceInterval, (byte) bits);
|
||||
float[] expectedCorrections = new float[numVectors];
|
||||
byte[][] expectedVectors = new byte[numVectors][];
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
float[] vector = vectors.get(i);
|
||||
if (normalize) {
|
||||
if (similarityFunction == VectorSimilarityFunction.COSINE) {
|
||||
float[] copy = new float[vector.length];
|
||||
System.arraycopy(vector, 0, copy, 0, copy.length);
|
||||
VectorUtil.l2normalize(copy);
|
||||
|
|
|
@ -114,19 +114,12 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
|
|||
vectors.add(randomVector(dim));
|
||||
}
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
confidenceInterval != null && confidenceInterval == 0f
|
||||
? ScalarQuantizer.fromVectorsAutoInterval(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
similarityFunction,
|
||||
numVectors,
|
||||
(byte) bits)
|
||||
: ScalarQuantizer.fromVectors(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||
confidenceInterval == null
|
||||
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
|
||||
: confidenceInterval,
|
||||
numVectors,
|
||||
(byte) bits);
|
||||
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors),
|
||||
numVectors,
|
||||
similarityFunction,
|
||||
confidenceInterval,
|
||||
(byte) bits);
|
||||
float[] expectedCorrections = new float[numVectors];
|
||||
byte[][] expectedVectors = new byte[numVectors][];
|
||||
for (int i = 0; i < numVectors; i++) {
|
||||
|
|
|
@ -23,6 +23,7 @@ import java.util.List;
|
|||
import org.apache.lucene.index.FloatVectorValues;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
import org.apache.lucene.util.quantization.ScalarQuantizer;
|
||||
|
||||
public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase {
|
||||
|
@ -87,12 +88,12 @@ public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase {
|
|||
for (int i = 0; i < 30; i++) {
|
||||
float[] vector = new float[] {i, i + 1, i + 2, i + 3};
|
||||
vectors.add(vector);
|
||||
if (vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
|
||||
VectorUtil.l2normalize(vector);
|
||||
}
|
||||
}
|
||||
FloatVectorValues vectorValues =
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(
|
||||
vectors,
|
||||
vectorSimilarityFunction == VectorSimilarityFunction.COSINE
|
||||
|| vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT);
|
||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
|
||||
vectorValues, 30, vectorSimilarityFunction, confidenceInterval, bits);
|
||||
|
|
|
@ -25,6 +25,7 @@ import org.apache.lucene.index.FloatVectorValues;
|
|||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.search.VectorScorer;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
public class TestScalarQuantizer extends LuceneTestCase {
|
||||
|
||||
|
@ -33,8 +34,16 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
int dims = random().nextInt(9) + 1;
|
||||
int numVecs = random().nextInt(9) + 10;
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
if (function == VectorSimilarityFunction.COSINE) {
|
||||
for (float[] v : floats) {
|
||||
VectorUtil.l2normalize(v);
|
||||
}
|
||||
}
|
||||
for (byte bits : new byte[] {4, 7}) {
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
if (function == VectorSimilarityFunction.COSINE) {
|
||||
function = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
}
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
random().nextBoolean()
|
||||
? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits)
|
||||
|
@ -63,11 +72,15 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
expectThrows(
|
||||
IllegalStateException.class,
|
||||
() -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits));
|
||||
VectorSimilarityFunction actualFunction =
|
||||
function == VectorSimilarityFunction.COSINE
|
||||
? VectorSimilarityFunction.DOT_PRODUCT
|
||||
: function;
|
||||
expectThrows(
|
||||
IllegalStateException.class,
|
||||
() ->
|
||||
ScalarQuantizer.fromVectorsAutoInterval(
|
||||
floatVectorValues, function, numVecs, bits));
|
||||
floatVectorValues, actualFunction, numVecs, bits));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -185,6 +198,9 @@ public class TestScalarQuantizer extends LuceneTestCase {
|
|||
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
|
||||
|
||||
float[][] floats = randomFloats(numVecs, dims);
|
||||
for (float[] v : floats) {
|
||||
VectorUtil.l2normalize(v);
|
||||
}
|
||||
FloatVectorValues floatVectorValues = fromFloats(floats);
|
||||
ScalarQuantizer scalarQuantizer =
|
||||
ScalarQuantizer.fromVectorsAutoInterval(
|
||||
|
|
Loading…
Reference in New Issue