# Normalize written scalar quantized vectors when using cosine similarity (#12780)

### Description

When using cosine similarity, the `ScalarQuantizer` normalizes vectors when calculating quantiles, and `ScalarQuantizedRandomVectorScorer` normalizes query vectors before scoring them, but `Lucene99ScalarQuantizedVectorsWriter` does not normalize vectors before quantizing them for the on-disk representation. The quantiles are therefore derived from normalized vectors but applied to unnormalized ones at write time. This PR normalizes each vector prior to quantization when writing to disk, so the stored quantized vectors are consistent with both the quantiles and the normalized query vectors.

Recall results on my M1 with the `glove-100-angular` data set (all runs using `maxConn`: 16, `beamWidth`: 100, `numCandidates`: 100, `k`: 10, single segment):
| Configuration | Recall | Average Query Duration |
|---------------|-------|-----------------| 
| Pre-patch no quantization | 0.78762 | 0.68 ms |
| Pre-patch with quantization | 8.999999999999717E-5 | 0.45 ms |
| Post-patch no quantization | 0.78762 | 0.70 ms |
| Post-patch with quantization | 0.66742 | 0.66 ms |
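
The pre-patch quantized recall of roughly 9e-5 is the signature of a range mismatch: the quantiles are computed from normalized vectors, so components of unnormalized vectors routinely fall outside the quantile range and clamp to the extreme buckets, collapsing most vectors onto nearly identical quantized codes. Here is a minimal, self-contained sketch of the effect (plain linear quantization with an assumed quantile range of [-1, 1]; the `NormalizeThenQuantize` class and its helpers are illustrative, not Lucene's actual `ScalarQuantizer`):

```java
// Illustrative only: a plain linear quantizer standing in for Lucene's
// ScalarQuantizer. The point is the clamping behavior, not the exact scheme.
final class NormalizeThenQuantize {
  // L2-normalize in place, analogous to VectorUtil.l2normalize.
  static void l2normalize(float[] v) {
    double norm = 0;
    for (float x : v) norm += x * x;
    norm = Math.sqrt(norm);
    for (int i = 0; i < v.length; i++) v[i] /= (float) norm;
  }

  // Linearly quantize x from [min, max] into a 0..127 bucket, clamping outliers.
  static byte quantize(float x, float min, float max) {
    float clamped = Math.max(min, Math.min(max, x));
    return (byte) Math.round((clamped - min) / (max - min) * 127f);
  }

  public static void main(String[] args) {
    float[] raw = {3f, 4f}; // norm 5: components far outside quantiles from unit vectors
    float min = -1f, max = 1f; // assumed quantile range derived from normalized data
    // Unnormalized, both components clamp to the same top bucket:
    System.out.println(quantize(raw[0], min, max)); // 127
    System.out.println(quantize(raw[1], min, max)); // 127
    // Normalized first, they remain distinguishable:
    l2normalize(raw); // raw -> {0.6, 0.8}
    System.out.println(quantize(raw[0], min, max)); // 102
    System.out.println(quantize(raw[1], min, max)); // 114
  }
}
```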
Authored by Kevin Rosendahl on 2023-11-08 (commit ddb01cacd4, parent 20d5de448a).
3 changed files with 51 additions and 15 deletions

#### Lucene99ScalarQuantizedVectorsWriter.java

```diff
@@ -142,7 +142,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
     ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
     byte[] vector = new byte[fieldData.dim];
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+    float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
     for (float[] v : fieldData.floatVectors) {
+      if (fieldData.normalize) {
+        System.arraycopy(v, 0, copy, 0, copy.length);
+        VectorUtil.l2normalize(copy);
+        v = copy;
+      }
       float offsetCorrection =
           scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
       quantizedVectorData.writeBytes(vector, vector.length);
@@ -194,8 +201,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
     ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
     byte[] vector = new byte[fieldData.dim];
     final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
+    float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
     for (int ordinal : ordMap) {
       float[] v = fieldData.floatVectors.get(ordinal);
+      if (fieldData.normalize) {
+        System.arraycopy(v, 0, copy, 0, copy.length);
+        VectorUtil.l2normalize(copy);
+        v = copy;
+      }
       float offsetCorrection =
           scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
       quantizedVectorData.writeBytes(vector, vector.length);
```
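
Note the scratch `copy` buffer in both loops above: `VectorUtil.l2normalize` normalizes in place, so working on a copy, allocated once outside the loop, leaves the raw vectors in `fieldData.floatVectors` untouched for later use while avoiding a per-vector allocation.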

#### TestLucene99HnswQuantizedVectorsFormat.java

```diff
@@ -38,8 +38,10 @@ import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.util.ScalarQuantizer;
+import org.apache.lucene.util.VectorUtil;

 public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
   @Override
   protected Codec getCodec() {
     return new Lucene99Codec() {
@@ -57,6 +59,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
     // create lucene directory with codec
     int numVectors = 1 + random().nextInt(50);
     VectorSimilarityFunction similarityFunction = randomSimilarity();
+    boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
     int dim = random().nextInt(64) + 1;
     List<float[]> vectors = new ArrayList<>(numVectors);
     for (int i = 0; i < numVectors; i++) {
@@ -65,13 +68,22 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
     float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
     ScalarQuantizer scalarQuantizer =
         ScalarQuantizer.fromVectors(
-            new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile);
+            new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
+            quantile);
     float[] expectedCorrections = new float[numVectors];
     byte[][] expectedVectors = new byte[numVectors][];
     for (int i = 0; i < numVectors; i++) {
+      float[] vector = vectors.get(i);
+      if (normalize) {
+        float[] copy = new float[vector.length];
+        System.arraycopy(vector, 0, copy, 0, copy.length);
+        VectorUtil.l2normalize(copy);
+        vector = copy;
+      }
       expectedVectors[i] = new byte[dim];
       expectedCorrections[i] =
-          scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction);
+          scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction);
     }
     float[] randomlyReusedVector = new float[dim];
```

#### BaseKnnVectorsFormatTestCase.java

```diff
@@ -81,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
   protected void addRandomFields(Document doc) {
     switch (vectorEncoding) {
       case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
-      case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomVector(30), similarityFunction));
+      case FLOAT32 -> doc.add(
+          new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction));
     }
   }
@@ -611,7 +612,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
       // assert that knn search doesn't fail on a field with all deleted docs
       TopDocs results =
           leafReader.searchNearestVectors(
-              "v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
+              "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
       assertEquals(0, results.scoreDocs.length);
     }
   }
@@ -664,7 +665,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
             fieldTotals[field] += b[0];
           }
           case FLOAT32 -> {
-            float[] v = randomVector(fieldDims[field]);
+            float[] v = randomNormalizedVector(fieldDims[field]);
             doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field]));
             fieldTotals[field] += v[0];
           }
@@ -885,7 +886,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
     for (int i = 0; i < numDoc; i++) {
       if (random().nextInt(7) != 3) {
         // usually index a vector value for a doc
-        values[i] = randomVector(dimension);
+        values[i] = randomNormalizedVector(dimension);
         ++numValues;
       }
       if (random().nextBoolean() && values[i] != null) {
@@ -1033,7 +1034,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
       float[] value;
       if (random().nextInt(7) != 3) {
         // usually index a vector value for a doc
-        value = randomVector(dimension);
+        value = randomNormalizedVector(dimension);
       } else {
         value = null;
       }
@@ -1061,7 +1062,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
         TopDocs results =
             ctx.reader()
                 .searchNearestVectors(
-                    fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+                    fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
         assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
         assertEquals(visitedLimit, results.totalHits.value);
@@ -1071,7 +1072,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
         results =
             ctx.reader()
                 .searchNearestVectors(
-                    fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
+                    fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
         assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
         assertTrue(results.totalHits.value <= visitedLimit);
       }
@@ -1097,7 +1098,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
       float[] value;
       if (random().nextInt(7) != 3) {
         // usually index a vector value for a doc
-        value = randomVector(dimension);
+        value = randomNormalizedVector(dimension);
       } else {
         value = null;
       }
@@ -1147,7 +1148,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
         TopDocs results =
             ctx.reader()
                 .searchNearestVectors(
-                    fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE);
+                    fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE);
         assertEquals(Math.min(k, size), results.scoreDocs.length);
         for (int i = 0; i < k - 1; i++) {
           assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
@@ -1233,13 +1234,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
         squareSum += v[i] * v[i];
       }
     }
     return v;
   }

+  protected float[] randomNormalizedVector(int dim) {
+    float[] v = randomVector(dim);
+    VectorUtil.l2normalize(v);
+    return v;
+  }
+
   private byte[] randomVector8(int dim) {
     assert dim > 0;
-    float[] v = randomVector(dim);
+    float[] v = randomNormalizedVector(dim);
     byte[] b = new byte[dim];
     for (int i = 0; i < dim; i++) {
       b[i] = (byte) (v[i] * 127);
@@ -1251,10 +1257,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
     try (Directory dir = newDirectory()) {
       try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
         Document doc = new Document();
-        doc.add(new KnnFloatVectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
+        doc.add(
+            new KnnFloatVectorField(
+                "v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
         w.addDocument(doc);
-        doc.add(new KnnFloatVectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
+        doc.add(
+            new KnnFloatVectorField(
+                "v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
         w.addDocument(doc);
       }
@@ -1359,7 +1369,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTestCase {
           doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
         }
         case FLOAT32 -> {
-          float[] v = randomVector(dim);
+          float[] v = randomNormalizedVector(dim);
           fieldValuesCheckSum += v[0];
           doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction));
         }
```