Fix ScalarQuantization when used with COSINE similarity (#13615)

When quantizing vectors in a COSINE vector space, we normalize them. However, there was a bug when building the quantizer quantiles: we did not always use the normalized vectors. Consequently, we could end up with poorly configured quantiles, and recall would drop significantly (especially in sensitive cases like int4 quantization).
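To make the failure mode concrete: with COSINE similarity the vectors that are actually compared are unit length, so quantiles computed from the raw, unnormalized values cover the wrong dynamic range and most of the quantized buckets go unused. A toy illustration, not code from this change; it only uses Lucene's VectorUtil plus made-up vectors:

import org.apache.lucene.util.VectorUtil;

public class QuantileRangeDemo {
  public static void main(String[] args) {
    float[][] raw = {{0.1f, 0.2f}, {30f, 40f}, {3f, 4f}};
    float rawMin = Float.MAX_VALUE, rawMax = -Float.MAX_VALUE;
    float unitMin = Float.MAX_VALUE, unitMax = -Float.MAX_VALUE;
    for (float[] v : raw) {
      for (float x : v) {
        rawMin = Math.min(rawMin, x);
        rawMax = Math.max(rawMax, x);
      }
      float[] unit = v.clone();
      VectorUtil.l2normalize(unit); // what cosine scoring actually sees
      for (float x : unit) {
        unitMin = Math.min(unitMin, x);
        unitMax = Math.max(unitMax, x);
      }
    }
    // raw range ~[0.1, 40.0] vs normalized range ~[0.45, 0.90]:
    // quantiles built on the raw values waste nearly all of the int4/int8 buckets.
    System.out.println("raw range:        [" + rawMin + ", " + rawMax + "]");
    System.out.println("normalized range: [" + unitMin + ", " + unitMax + "]");
  }
}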

closes: #13614
Benjamin Trent 2024-08-05 12:29:14 -04:00 committed by GitHub
parent 26b46ced07
commit 43c80117dd
8 changed files with 118 additions and 64 deletions

CHANGES.txt

@@ -350,6 +350,9 @@ Bug Fixes
* GITHUB#13553: Correct RamUsageEstimate for scalar quantized knn vector formats so that raw vectors are correctly
accounted for. (Ben Trent)
* GITHUB#13615: Correct scalar quantization when used in conjunction with COSINE similarity. Vectors are normalized
before quantization to ensure the cosine similarity is correctly calculated. (Ben Trent)
Other
--------------------
(No changes)

Lucene99ScalarQuantizedVectorsWriter.java

@@ -677,6 +677,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
Float confidenceInterval,
byte bits)
throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues);
vectorSimilarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
}
if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) {
return ScalarQuantizer.fromVectorsAutoInterval(
floatVectorValues, vectorSimilarityFunction, numVectors, bits);
@@ -797,10 +801,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
if (floatVectors.size() == 0) {
return new ScalarQuantizer(0, 0, bits);
}
FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize);
ScalarQuantizer quantizer =
buildScalarQuantizer(
floatVectorValues,
new FloatVectorWrapper(floatVectors),
floatVectors.size(),
fieldInfo.getVectorSimilarityFunction(),
confidenceInterval,
@@ -851,14 +854,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
static class FloatVectorWrapper extends FloatVectorValues {
private final List<float[]> vectorList;
private final float[] copy;
private final boolean normalize;
protected int curDoc = -1;
FloatVectorWrapper(List<float[]> vectorList, boolean normalize) {
FloatVectorWrapper(List<float[]> vectorList) {
this.vectorList = vectorList;
this.copy = new float[vectorList.get(0).length];
this.normalize = normalize;
}
@Override
@@ -876,11 +875,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
if (curDoc == -1 || curDoc >= vectorList.size()) {
throw new IOException("Current doc not set or too many iterations");
}
if (normalize) {
System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
return copy;
}
return vectorList.get(curDoc);
}
@@ -949,13 +943,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
// quantization?
|| scalarQuantizer.getBits() <= 4
|| shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) {
FloatVectorValues toQuantize =
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name);
if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) {
toQuantize = new NormalizedFloatVectorValues(toQuantize);
}
sub =
new QuantizedByteVectorValueSub(
mergeState.docMaps[i],
new QuantizedFloatVectorValues(
mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name),
fieldInfo.getVectorSimilarityFunction(),
scalarQuantizer));
toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer));
} else {
sub =
new QuantizedByteVectorValueSub(
@@ -1042,7 +1039,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
private final FloatVectorValues values;
private final ScalarQuantizer quantizer;
private final byte[] quantizedVector;
private final float[] normalizedVector;
private float offsetValue = 0f;
private final VectorSimilarityFunction vectorSimilarityFunction;
@@ -1055,11 +1051,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
this.quantizer = quantizer;
this.quantizedVector = new byte[values.dimension()];
this.vectorSimilarityFunction = vectorSimilarityFunction;
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
this.normalizedVector = new float[values.dimension()];
} else {
this.normalizedVector = null;
}
}
@Override
@@ -1111,15 +1102,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
}
private void quantize() throws IOException {
if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
offsetValue =
quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction);
} else {
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
offsetValue =
quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction);
}
}
@@ -1216,4 +1200,60 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite
throw new UnsupportedOperationException();
}
}
static final class NormalizedFloatVectorValues extends FloatVectorValues {
private final FloatVectorValues values;
private final float[] normalizedVector;
int curDoc = -1;
public NormalizedFloatVectorValues(FloatVectorValues values) {
this.values = values;
this.normalizedVector = new float[values.dimension()];
}
@Override
public int dimension() {
return values.dimension();
}
@Override
public int size() {
return values.size();
}
@Override
public float[] vectorValue() throws IOException {
return normalizedVector;
}
@Override
public VectorScorer scorer(float[] query) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public int docID() {
return values.docID();
}
@Override
public int nextDoc() throws IOException {
curDoc = values.nextDoc();
if (curDoc != NO_MORE_DOCS) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
}
return curDoc;
}
@Override
public int advance(int target) throws IOException {
curDoc = values.advance(target);
if (curDoc != NO_MORE_DOCS) {
System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length);
VectorUtil.l2normalize(normalizedVector);
}
return curDoc;
}
}
}

VectorUtil.java

@@ -47,6 +47,8 @@ import org.apache.lucene.internal.vectorization.VectorizationProvider;
*/
public final class VectorUtil {
private static final float EPSILON = 1e-4f;
private static final VectorUtilSupport IMPL =
VectorizationProvider.getInstance().getVectorUtilSupport();
@@ -121,6 +123,11 @@ public final class VectorUtil {
return v;
}
public static boolean isUnitVector(float[] v) {
double l1norm = IMPL.dotProduct(v, v);
return Math.abs(l1norm - 1.0d) <= EPSILON;
}
/**
* Modifies the argument to be unit length, dividing by its l2-norm.
*
@@ -138,7 +145,7 @@ public final class VectorUtil {
return v;
}
}
if (Math.abs(l1norm - 1.0d) <= 1e-5) {
if (Math.abs(l1norm - 1.0d) <= EPSILON) {
return v;
}
int dim = v.length;
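isUnitVector is the guard the new assertions rely on: it checks that the squared l2 norm is within EPSILON (1e-4) of 1, so vectors that were normalized earlier and drifted slightly under float arithmetic still pass. A small illustration with made-up values:

import org.apache.lucene.util.VectorUtil;

public class IsUnitVectorDemo {
  public static void main(String[] args) {
    System.out.println(VectorUtil.isUnitVector(new float[] {0.6f, 0.8f})); // true: 0.36 + 0.64 = 1.0
    System.out.println(VectorUtil.isUnitVector(new float[] {3f, 4f}));     // false: squared norm is 25
    float[] v = {3f, 4f};
    VectorUtil.l2normalize(v); // in place; v becomes {0.6, 0.8}
    System.out.println(VectorUtil.isUnitVector(v));                        // true after normalization
  }
}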

ScalarQuantizer.java

@@ -30,6 +30,7 @@ import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.IntroSelector;
import org.apache.lucene.util.Selector;
import org.apache.lucene.util.VectorUtil;
/**
* Will scalar quantize float vectors into `int8` byte values. This is a lossy transformation.
@@ -113,6 +114,7 @@ public class ScalarQuantizer {
*/
public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) {
assert src.length == dest.length;
assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src);
float correction = 0;
for (int i = 0; i < src.length; i++) {
correction += quantizeFloat(src[i], dest, i);
@@ -332,6 +334,7 @@ public class ScalarQuantizer {
int totalVectorCount,
byte bits)
throws IOException {
assert function != VectorSimilarityFunction.COSINE;
if (totalVectorCount == 0) {
return new ScalarQuantizer(0f, 0f, bits);
}
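Taken together, the new assertions make the contract explicit: ScalarQuantizer never sees COSINE directly; callers normalize first and quantize as DOT_PRODUCT. A minimal sketch of a caller honoring that contract (the helper class and method names are illustrative, not part of this change):

import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.ScalarQuantizer;

final class CosineQuantizeSketch {
  /** Quantize one vector from a COSINE field: normalize a copy, then treat it as DOT_PRODUCT. */
  static float quantizeForCosine(ScalarQuantizer quantizer, float[] vector, byte[] dest) {
    float[] unit = vector.clone();
    VectorUtil.l2normalize(unit); // satisfies the new isUnitVector assertion in quantize()
    // Quantiles were built from normalized vectors, so the dot-product correction applies.
    return quantizer.quantize(unit, dest, VectorSimilarityFunction.DOT_PRODUCT);
  }
}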

TestLucene99HnswQuantizedVectorsFormat.java

@@ -35,6 +35,7 @@ import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.CodecReader;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
@@ -127,7 +128,6 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
// create lucene directory with codec
int numVectors = 1 + random().nextInt(50);
VectorSimilarityFunction similarityFunction = randomSimilarity();
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
int dim = random().nextInt(64) + 1;
if (dim % 2 == 1) {
dim++;
@@ -136,25 +136,16 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
for (int i = 0; i < numVectors; i++) {
vectors.add(randomVector(dim));
}
FloatVectorValues toQuantize =
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors);
ScalarQuantizer scalarQuantizer =
confidenceInterval != null && confidenceInterval == 0f
? ScalarQuantizer.fromVectorsAutoInterval(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
similarityFunction,
numVectors,
(byte) bits)
: ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
confidenceInterval == null
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
: confidenceInterval,
numVectors,
(byte) bits);
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
toQuantize, numVectors, similarityFunction, confidenceInterval, (byte) bits);
float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) {
float[] vector = vectors.get(i);
if (normalize) {
if (similarityFunction == VectorSimilarityFunction.COSINE) {
float[] copy = new float[vector.length];
System.arraycopy(vector, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);

TestLucene99ScalarQuantizedVectorsFormat.java

@@ -114,19 +114,12 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm
vectors.add(randomVector(dim));
}
ScalarQuantizer scalarQuantizer =
confidenceInterval != null && confidenceInterval == 0f
? ScalarQuantizer.fromVectorsAutoInterval(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
similarityFunction,
numVectors,
(byte) bits)
: ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
confidenceInterval == null
? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim)
: confidenceInterval,
numVectors,
(byte) bits);
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors),
numVectors,
similarityFunction,
confidenceInterval,
(byte) bits);
float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) {

TestLucene99ScalarQuantizedVectorsWriter.java

@@ -23,6 +23,7 @@ import java.util.List;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase {
@@ -87,12 +88,12 @@ public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase {
for (int i = 0; i < 30; i++) {
float[] vector = new float[] {i, i + 1, i + 2, i + 3};
vectors.add(vector);
if (vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT) {
VectorUtil.l2normalize(vector);
}
}
FloatVectorValues vectorValues =
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(
vectors,
vectorSimilarityFunction == VectorSimilarityFunction.COSINE
|| vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT);
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors);
ScalarQuantizer scalarQuantizer =
Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer(
vectorValues, 30, vectorSimilarityFunction, confidenceInterval, bits);

TestScalarQuantizer.java

@@ -25,6 +25,7 @@ import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
public class TestScalarQuantizer extends LuceneTestCase {
@@ -33,8 +34,16 @@ public class TestScalarQuantizer extends LuceneTestCase {
int dims = random().nextInt(9) + 1;
int numVecs = random().nextInt(9) + 10;
float[][] floats = randomFloats(numVecs, dims);
if (function == VectorSimilarityFunction.COSINE) {
for (float[] v : floats) {
VectorUtil.l2normalize(v);
}
}
for (byte bits : new byte[] {4, 7}) {
FloatVectorValues floatVectorValues = fromFloats(floats);
if (function == VectorSimilarityFunction.COSINE) {
function = VectorSimilarityFunction.DOT_PRODUCT;
}
ScalarQuantizer scalarQuantizer =
random().nextBoolean()
? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits)
@@ -63,11 +72,15 @@ public class TestScalarQuantizer extends LuceneTestCase {
expectThrows(
IllegalStateException.class,
() -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits));
VectorSimilarityFunction actualFunction =
function == VectorSimilarityFunction.COSINE
? VectorSimilarityFunction.DOT_PRODUCT
: function;
expectThrows(
IllegalStateException.class,
() ->
ScalarQuantizer.fromVectorsAutoInterval(
floatVectorValues, function, numVecs, bits));
floatVectorValues, actualFunction, numVecs, bits));
}
}
}
@@ -185,6 +198,9 @@ public class TestScalarQuantizer extends LuceneTestCase {
VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT;
float[][] floats = randomFloats(numVecs, dims);
for (float[] v : floats) {
VectorUtil.l2normalize(v);
}
FloatVectorValues floatVectorValues = fromFloats(floats);
ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectorsAutoInterval(