From 43c80117dd51a01e0585242a89de2126a3fea059 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Mon, 5 Aug 2024 12:29:14 -0400 Subject: [PATCH] Fix ScalarQuantization when used with COSINE similarity (#13615) When quantizing vectors in a COSINE vector space, we normalize them. However, there is a bug when building the quantizer quantiles and we didn't always use the normalized vectors. Consequently, we would end up with poorly configured quantiles and recall will drop significantly (especially in sensitive cases like int4). closes: #13614 --- lucene/CHANGES.txt | 3 + .../Lucene99ScalarQuantizedVectorsWriter.java | 100 ++++++++++++------ .../org/apache/lucene/util/VectorUtil.java | 9 +- .../util/quantization/ScalarQuantizer.java | 3 + ...estLucene99HnswQuantizedVectorsFormat.java | 21 ++-- ...tLucene99ScalarQuantizedVectorsFormat.java | 19 ++-- ...tLucene99ScalarQuantizedVectorsWriter.java | 9 +- .../quantization/TestScalarQuantizer.java | 18 +++- 8 files changed, 118 insertions(+), 64 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index fb91b89a2a4..739887b0b91 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -350,6 +350,9 @@ Bug Fixes * GITHUB#13553: Correct RamUsageEstimate for scalar quantized knn vector formats so that raw vectors are correctly accounted for. (Ben Trent) +* GITHUB#13615: Correct scalar quantization when used in conjunction with COSINE similarity. Vectors are normalized + before quantization to ensure the cosine similarity is correctly calculated. (Ben Trent) + Other -------------------- (No changes) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 311f2df435e..e477fec75e5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -677,6 +677,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite Float confidenceInterval, byte bits) throws IOException { + if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { + floatVectorValues = new NormalizedFloatVectorValues(floatVectorValues); + vectorSimilarityFunction = VectorSimilarityFunction.DOT_PRODUCT; + } if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) { return ScalarQuantizer.fromVectorsAutoInterval( floatVectorValues, vectorSimilarityFunction, numVectors, bits); @@ -797,10 +801,9 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite if (floatVectors.size() == 0) { return new ScalarQuantizer(0, 0, bits); } - FloatVectorValues floatVectorValues = new FloatVectorWrapper(floatVectors, normalize); ScalarQuantizer quantizer = buildScalarQuantizer( - floatVectorValues, + new FloatVectorWrapper(floatVectors), floatVectors.size(), fieldInfo.getVectorSimilarityFunction(), confidenceInterval, @@ -851,14 +854,10 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite static class FloatVectorWrapper extends FloatVectorValues { private final List vectorList; - private final float[] copy; - private final boolean normalize; protected int curDoc = -1; - FloatVectorWrapper(List vectorList, boolean normalize) { + FloatVectorWrapper(List vectorList) { this.vectorList = vectorList; - this.copy = new float[vectorList.get(0).length]; - this.normalize = normalize; } @Override @@ -876,11 +875,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite if (curDoc == -1 || curDoc >= vectorList.size()) { throw new IOException("Current doc not set or too many iterations"); } - if (normalize) { - System.arraycopy(vectorList.get(curDoc), 0, copy, 0, copy.length); - VectorUtil.l2normalize(copy); - return copy; - } return vectorList.get(curDoc); } @@ -949,13 +943,16 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite // quantization? || scalarQuantizer.getBits() <= 4 || shouldRequantize(reader.getQuantizationState(fieldInfo.name), scalarQuantizer)) { + FloatVectorValues toQuantize = + mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name); + if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { + toQuantize = new NormalizedFloatVectorValues(toQuantize); + } sub = new QuantizedByteVectorValueSub( mergeState.docMaps[i], new QuantizedFloatVectorValues( - mergeState.knnVectorsReaders[i].getFloatVectorValues(fieldInfo.name), - fieldInfo.getVectorSimilarityFunction(), - scalarQuantizer)); + toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer)); } else { sub = new QuantizedByteVectorValueSub( @@ -1042,7 +1039,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite private final FloatVectorValues values; private final ScalarQuantizer quantizer; private final byte[] quantizedVector; - private final float[] normalizedVector; private float offsetValue = 0f; private final VectorSimilarityFunction vectorSimilarityFunction; @@ -1055,11 +1051,6 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite this.quantizer = quantizer; this.quantizedVector = new byte[values.dimension()]; this.vectorSimilarityFunction = vectorSimilarityFunction; - if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { - this.normalizedVector = new float[values.dimension()]; - } else { - this.normalizedVector = null; - } } @Override @@ -1111,15 +1102,8 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite } private void quantize() throws IOException { - if (vectorSimilarityFunction == VectorSimilarityFunction.COSINE) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - offsetValue = - quantizer.quantize(normalizedVector, quantizedVector, vectorSimilarityFunction); - } else { - offsetValue = - quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); - } + offsetValue = + quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); } } @@ -1216,4 +1200,60 @@ public final class Lucene99ScalarQuantizedVectorsWriter extends FlatVectorsWrite throw new UnsupportedOperationException(); } } + + static final class NormalizedFloatVectorValues extends FloatVectorValues { + private final FloatVectorValues values; + private final float[] normalizedVector; + int curDoc = -1; + + public NormalizedFloatVectorValues(FloatVectorValues values) { + this.values = values; + this.normalizedVector = new float[values.dimension()]; + } + + @Override + public int dimension() { + return values.dimension(); + } + + @Override + public int size() { + return values.size(); + } + + @Override + public float[] vectorValue() throws IOException { + return normalizedVector; + } + + @Override + public VectorScorer scorer(float[] query) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int docID() { + return values.docID(); + } + + @Override + public int nextDoc() throws IOException { + curDoc = values.nextDoc(); + if (curDoc != NO_MORE_DOCS) { + System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + } + return curDoc; + } + + @Override + public int advance(int target) throws IOException { + curDoc = values.advance(target); + if (curDoc != NO_MORE_DOCS) { + System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); + } + return curDoc; + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index f122ae95544..e1c3978cff3 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -47,6 +47,8 @@ import org.apache.lucene.internal.vectorization.VectorizationProvider; */ public final class VectorUtil { + private static final float EPSILON = 1e-4f; + private static final VectorUtilSupport IMPL = VectorizationProvider.getInstance().getVectorUtilSupport(); @@ -121,6 +123,11 @@ public final class VectorUtil { return v; } + public static boolean isUnitVector(float[] v) { + double l1norm = IMPL.dotProduct(v, v); + return Math.abs(l1norm - 1.0d) <= EPSILON; + } + /** * Modifies the argument to be unit length, dividing by its l2-norm. * @@ -138,7 +145,7 @@ public final class VectorUtil { return v; } } - if (Math.abs(l1norm - 1.0d) <= 1e-5) { + if (Math.abs(l1norm - 1.0d) <= EPSILON) { return v; } int dim = v.length; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index fb07e005571..44c0ac5aca4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -30,6 +30,7 @@ import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.util.IntroSelector; import org.apache.lucene.util.Selector; +import org.apache.lucene.util.VectorUtil; /** * Will scalar quantize float vectors into `int8` byte values. This is a lossy transformation. @@ -113,6 +114,7 @@ public class ScalarQuantizer { */ public float quantize(float[] src, byte[] dest, VectorSimilarityFunction similarityFunction) { assert src.length == dest.length; + assert similarityFunction != VectorSimilarityFunction.COSINE || VectorUtil.isUnitVector(src); float correction = 0; for (int i = 0; i < src.length; i++) { correction += quantizeFloat(src[i], dest, i); @@ -332,6 +334,7 @@ public class ScalarQuantizer { int totalVectorCount, byte bits) throws IOException { + assert function != VectorSimilarityFunction.COSINE; if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 8e69e833b98..3098ca7fbf3 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -35,6 +35,7 @@ import org.apache.lucene.document.Document; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.CodecReader; import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -127,7 +128,6 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat // create lucene directory with codec int numVectors = 1 + random().nextInt(50); VectorSimilarityFunction similarityFunction = randomSimilarity(); - boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE; int dim = random().nextInt(64) + 1; if (dim % 2 == 1) { dim++; @@ -136,25 +136,16 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat for (int i = 0; i < numVectors; i++) { vectors.add(randomVector(dim)); } + FloatVectorValues toQuantize = + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = - confidenceInterval != null && confidenceInterval == 0f - ? ScalarQuantizer.fromVectorsAutoInterval( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), - similarityFunction, - numVectors, - (byte) bits) - : ScalarQuantizer.fromVectors( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), - confidenceInterval == null - ? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim) - : confidenceInterval, - numVectors, - (byte) bits); + Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( + toQuantize, numVectors, similarityFunction, confidenceInterval, (byte) bits); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { float[] vector = vectors.get(i); - if (normalize) { + if (similarityFunction == VectorSimilarityFunction.COSINE) { float[] copy = new float[vector.length]; System.arraycopy(vector, 0, copy, 0, copy.length); VectorUtil.l2normalize(copy); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index b221cb19dde..094d90ba5a2 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -114,19 +114,12 @@ public class TestLucene99ScalarQuantizedVectorsFormat extends BaseKnnVectorsForm vectors.add(randomVector(dim)); } ScalarQuantizer scalarQuantizer = - confidenceInterval != null && confidenceInterval == 0f - ? ScalarQuantizer.fromVectorsAutoInterval( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), - similarityFunction, - numVectors, - (byte) bits) - : ScalarQuantizer.fromVectors( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), - confidenceInterval == null - ? Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval(dim) - : confidenceInterval, - numVectors, - (byte) bits); + Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors), + numVectors, + similarityFunction, + confidenceInterval, + (byte) bits); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java index 34af5ea3e15..0bf2a4ef6b8 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java @@ -23,6 +23,7 @@ import java.util.List; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.ScalarQuantizer; public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase { @@ -87,12 +88,12 @@ public class TestLucene99ScalarQuantizedVectorsWriter extends LuceneTestCase { for (int i = 0; i < 30; i++) { float[] vector = new float[] {i, i + 1, i + 2, i + 3}; vectors.add(vector); + if (vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT) { + VectorUtil.l2normalize(vector); + } } FloatVectorValues vectorValues = - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper( - vectors, - vectorSimilarityFunction == VectorSimilarityFunction.COSINE - || vectorSimilarityFunction == VectorSimilarityFunction.DOT_PRODUCT); + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( vectorValues, 30, vectorSimilarityFunction, confidenceInterval, bits); diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 0ee2c01aa29..48eb7ce651c 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -25,6 +25,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.tests.util.LuceneTestCase; +import org.apache.lucene.util.VectorUtil; public class TestScalarQuantizer extends LuceneTestCase { @@ -33,8 +34,16 @@ public class TestScalarQuantizer extends LuceneTestCase { int dims = random().nextInt(9) + 1; int numVecs = random().nextInt(9) + 10; float[][] floats = randomFloats(numVecs, dims); + if (function == VectorSimilarityFunction.COSINE) { + for (float[] v : floats) { + VectorUtil.l2normalize(v); + } + } for (byte bits : new byte[] {4, 7}) { FloatVectorValues floatVectorValues = fromFloats(floats); + if (function == VectorSimilarityFunction.COSINE) { + function = VectorSimilarityFunction.DOT_PRODUCT; + } ScalarQuantizer scalarQuantizer = random().nextBoolean() ? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits) @@ -63,11 +72,15 @@ public class TestScalarQuantizer extends LuceneTestCase { expectThrows( IllegalStateException.class, () -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits)); + VectorSimilarityFunction actualFunction = + function == VectorSimilarityFunction.COSINE + ? VectorSimilarityFunction.DOT_PRODUCT + : function; expectThrows( IllegalStateException.class, () -> ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, function, numVecs, bits)); + floatVectorValues, actualFunction, numVecs, bits)); } } } @@ -185,6 +198,9 @@ public class TestScalarQuantizer extends LuceneTestCase { VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; float[][] floats = randomFloats(numVecs, dims); + for (float[] v : floats) { + VectorUtil.l2normalize(v); + } FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectorsAutoInterval(