From ddb01cacd443b9c8b4c0321c522f0f3a141e0778 Mon Sep 17 00:00:00 2001 From: Kevin Rosendahl Date: Wed, 8 Nov 2023 11:26:48 -0800 Subject: [PATCH] Normalize written scalar quantized vectors when using cosine similarity (#12780) ### Description When using cosine similarity, the `ScalarQuantizer` normalizes vectors when calculating quantiles and `ScalarQuantizedRandomVectorScorer` normalizes query vectors before scoring them, but `Lucene99ScalarQuantizedVectorsWriter` does not normalize the vectors prior to quantizing them when producing the quantized vectors to write to disk. This PR normalizes vectors prior to quantizing them when writing them to disk. Recall results on my M1 with the `glove-100-angular` data set (all using `maxConn`: 16, `beamWidth` 100, `numCandidates`: 100, `k`: 10, single segment): | Configuration | Recall | Average Query Duration | |---------------|-------|-----------------| | Pre-patch no quantization | 0.78762 | 0.68 ms | | Pre-patch with quantization | 8.999999999999717E-5 | 0.45 ms | | Post-patch no quantization | 0.78762 | 0.70 ms | | Post-patch with quantization | 0.66742 | 0.66 ms | --- .../Lucene99ScalarQuantizedVectorsWriter.java | 14 ++++++++ ...estLucene99HnswQuantizedVectorsFormat.java | 16 +++++++-- .../index/BaseKnnVectorsFormatTestCase.java | 36 ++++++++++++------- 3 files changed, 51 insertions(+), 15 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 613fa89b020..93bc6b0011a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -142,7 +142,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.dim]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; for (float[] v : fieldData.floatVectors) { + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); quantizedVectorData.writeBytes(vector, vector.length); @@ -194,8 +201,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.dim]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; for (int ordinal : ordMap) { float[] v = fieldData.floatVectors.get(ordinal); + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); quantizedVectorData.writeBytes(vector, vector.length); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 4e262ca4c9d..d3d13d5ef61 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -38,8 +38,10 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.util.ScalarQuantizer; +import org.apache.lucene.util.VectorUtil; public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + @Override protected Codec getCodec() { return new Lucene99Codec() { @@ -57,6 +59,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat // create lucene directory with codec int numVectors = 1 + random().nextInt(50); VectorSimilarityFunction similarityFunction = randomSimilarity(); + boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE; int dim = random().nextInt(64) + 1; List vectors = new ArrayList<>(numVectors); for (int i = 0; i < numVectors; i++) { @@ -65,13 +68,22 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim); ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile); + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), + quantile); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { + float[] vector = vectors.get(i); + if (normalize) { + float[] copy = new float[vector.length]; + System.arraycopy(vector, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + vector = copy; + } + expectedVectors[i] = new byte[dim]; expectedCorrections[i] = - scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction); + scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction); } float[] randomlyReusedVector = new float[dim]; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 955c6ae7127..d6883793cfb 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -81,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe protected void addRandomFields(Document doc) { switch (vectorEncoding) { case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction)); - case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomVector(30), similarityFunction)); + case FLOAT32 -> doc.add( + new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction)); } } @@ -611,7 +612,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); + "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -664,7 +665,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe fieldTotals[field] += b[0]; } case FLOAT32 -> { - float[] v = randomVector(fieldDims[field]); + float[] v = randomNormalizedVector(fieldDims[field]); doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field])); fieldTotals[field] += v[0]; } @@ -885,7 +886,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe for (int i = 0; i < numDoc; i++) { if (random().nextInt(7) != 3) { // usually index a vector value for a doc - values[i] = randomVector(dimension); + values[i] = randomNormalizedVector(dimension); ++numValues; } if (random().nextBoolean() && values[i] != null) { @@ -1033,7 +1034,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe float[] value; if (random().nextInt(7) != 3) { // usually index a vector value for a doc - value = randomVector(dimension); + value = randomNormalizedVector(dimension); } else { value = null; } @@ -1061,7 +1062,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation); assertEquals(visitedLimit, results.totalHits.value); @@ -1071,7 +1072,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation); assertTrue(results.totalHits.value <= visitedLimit); } @@ -1097,7 +1098,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe float[] value; if (random().nextInt(7) != 3) { // usually index a vector value for a doc - value = randomVector(dimension); + value = randomNormalizedVector(dimension); } else { value = null; } @@ -1147,7 +1148,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE); + fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE); assertEquals(Math.min(k, size), results.scoreDocs.length); for (int i = 0; i < k - 1; i++) { assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); @@ -1233,13 +1234,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe squareSum += v[i] * v[i]; } } + return v; + } + + protected float[] randomNormalizedVector(int dim) { + float[] v = randomVector(dim); VectorUtil.l2normalize(v); return v; } private byte[] randomVector8(int dim) { assert dim > 0; - float[] v = randomVector(dim); + float[] v = randomNormalizedVector(dim); byte[] b = new byte[dim]; for (int i = 0; i < dim; i++) { b[i] = (byte) (v[i] * 127); @@ -1251,10 +1257,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new KnnFloatVectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); + doc.add( + new KnnFloatVectorField( + "v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); - doc.add(new KnnFloatVectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); + doc.add( + new KnnFloatVectorField( + "v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); } @@ -1359,7 +1369,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction)); } case FLOAT32 -> { - float[] v = randomVector(dim); + float[] v = randomNormalizedVector(dim); fieldValuesCheckSum += v[0]; doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction)); }