Normalize written scalar quantized vectors when using cosine similarity (#12780)

### Description

When using cosine similarity, the `ScalarQuantizer` normalizes vectors when calculating quantiles and `ScalarQuantizedRandomVectorScorer` normalizes query vectors before scoring them, but `Lucene99ScalarQuantizedVectorsWriter` does not normalize the vectors prior to quantizing them when producing the quantized vectors to write to disk. This PR normalizes vectors prior to quantizing them when writing them to disk.

Recall results on my M1 with the `glove-100-angular` data set (all using `maxConn`: 16, `beamWidth`: 100, `numCandidates`: 100, `k`: 10, single segment):
| Configuration | Recall | Average Query Duration |
|---------------|-------|-----------------| 
| Pre-patch no quantization | 0.78762 | 0.68 ms |
| Pre-patch with quantization | 8.999999999999717E-5 | 0.45 ms |
| Post-patch no quantization | 0.78762 | 0.70 ms |
| Post-patch with quantization | 0.66742 | 0.66 ms |
This commit is contained in:
Kevin Rosendahl 2023-11-08 11:26:48 -08:00 committed by GitHub
parent 20d5de448a
commit ddb01cacd4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 51 additions and 15 deletions

View File

@@ -142,7 +142,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim]; byte[] vector = new byte[fieldData.dim];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
for (float[] v : fieldData.floatVectors) { for (float[] v : fieldData.floatVectors) {
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}
float offsetCorrection = float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
quantizedVectorData.writeBytes(vector, vector.length); quantizedVectorData.writeBytes(vector, vector.length);
@@ -194,8 +201,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
byte[] vector = new byte[fieldData.dim]; byte[] vector = new byte[fieldData.dim];
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
for (int ordinal : ordMap) { for (int ordinal : ordMap) {
float[] v = fieldData.floatVectors.get(ordinal); float[] v = fieldData.floatVectors.get(ordinal);
if (fieldData.normalize) {
System.arraycopy(v, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
v = copy;
}
float offsetCorrection = float offsetCorrection =
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
quantizedVectorData.writeBytes(vector, vector.length); quantizedVectorData.writeBytes(vector, vector.length);

View File

@@ -38,8 +38,10 @@ import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory; import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
import org.apache.lucene.util.ScalarQuantizer; import org.apache.lucene.util.ScalarQuantizer;
import org.apache.lucene.util.VectorUtil;
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
@Override @Override
protected Codec getCodec() { protected Codec getCodec() {
return new Lucene99Codec() { return new Lucene99Codec() {
@@ -57,6 +59,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
// create lucene directory with codec // create lucene directory with codec
int numVectors = 1 + random().nextInt(50); int numVectors = 1 + random().nextInt(50);
VectorSimilarityFunction similarityFunction = randomSimilarity(); VectorSimilarityFunction similarityFunction = randomSimilarity();
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
int dim = random().nextInt(64) + 1; int dim = random().nextInt(64) + 1;
List<float[]> vectors = new ArrayList<>(numVectors); List<float[]> vectors = new ArrayList<>(numVectors);
for (int i = 0; i < numVectors; i++) { for (int i = 0; i < numVectors; i++) {
@@ -65,13 +68,22 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim); float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
ScalarQuantizer scalarQuantizer = ScalarQuantizer scalarQuantizer =
ScalarQuantizer.fromVectors( ScalarQuantizer.fromVectors(
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile); new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
quantile);
float[] expectedCorrections = new float[numVectors]; float[] expectedCorrections = new float[numVectors];
byte[][] expectedVectors = new byte[numVectors][]; byte[][] expectedVectors = new byte[numVectors][];
for (int i = 0; i < numVectors; i++) { for (int i = 0; i < numVectors; i++) {
float[] vector = vectors.get(i);
if (normalize) {
float[] copy = new float[vector.length];
System.arraycopy(vector, 0, copy, 0, copy.length);
VectorUtil.l2normalize(copy);
vector = copy;
}
expectedVectors[i] = new byte[dim]; expectedVectors[i] = new byte[dim];
expectedCorrections[i] = expectedCorrections[i] =
scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction); scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction);
} }
float[] randomlyReusedVector = new float[dim]; float[] randomlyReusedVector = new float[dim];

View File

@@ -81,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
protected void addRandomFields(Document doc) { protected void addRandomFields(Document doc) {
switch (vectorEncoding) { switch (vectorEncoding) {
case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction)); case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomVector(30), similarityFunction)); case FLOAT32 -> doc.add(
new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction));
} }
} }
@@ -611,7 +612,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
// assert that knn search doesn't fail on a field with all deleted docs // assert that knn search doesn't fail on a field with all deleted docs
TopDocs results = TopDocs results =
leafReader.searchNearestVectors( leafReader.searchNearestVectors(
"v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
assertEquals(0, results.scoreDocs.length); assertEquals(0, results.scoreDocs.length);
} }
} }
@@ -664,7 +665,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
fieldTotals[field] += b[0]; fieldTotals[field] += b[0];
} }
case FLOAT32 -> { case FLOAT32 -> {
float[] v = randomVector(fieldDims[field]); float[] v = randomNormalizedVector(fieldDims[field]);
doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field])); doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field]));
fieldTotals[field] += v[0]; fieldTotals[field] += v[0];
} }
@@ -885,7 +886,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
for (int i = 0; i < numDoc; i++) { for (int i = 0; i < numDoc; i++) {
if (random().nextInt(7) != 3) { if (random().nextInt(7) != 3) {
// usually index a vector value for a doc // usually index a vector value for a doc
values[i] = randomVector(dimension); values[i] = randomNormalizedVector(dimension);
++numValues; ++numValues;
} }
if (random().nextBoolean() && values[i] != null) { if (random().nextBoolean() && values[i] != null) {
@@ -1033,7 +1034,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
float[] value; float[] value;
if (random().nextInt(7) != 3) { if (random().nextInt(7) != 3) {
// usually index a vector value for a doc // usually index a vector value for a doc
value = randomVector(dimension); value = randomNormalizedVector(dimension);
} else { } else {
value = null; value = null;
} }
@@ -1061,7 +1062,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
TopDocs results = TopDocs results =
ctx.reader() ctx.reader()
.searchNearestVectors( .searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, visitedLimit); fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
assertEquals(visitedLimit, results.totalHits.value); assertEquals(visitedLimit, results.totalHits.value);
@@ -1071,7 +1072,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
results = results =
ctx.reader() ctx.reader()
.searchNearestVectors( .searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, visitedLimit); fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
assertTrue(results.totalHits.value <= visitedLimit); assertTrue(results.totalHits.value <= visitedLimit);
} }
@@ -1097,7 +1098,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
float[] value; float[] value;
if (random().nextInt(7) != 3) { if (random().nextInt(7) != 3) {
// usually index a vector value for a doc // usually index a vector value for a doc
value = randomVector(dimension); value = randomNormalizedVector(dimension);
} else { } else {
value = null; value = null;
} }
@@ -1147,7 +1148,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
TopDocs results = TopDocs results =
ctx.reader() ctx.reader()
.searchNearestVectors( .searchNearestVectors(
fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE); fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE);
assertEquals(Math.min(k, size), results.scoreDocs.length); assertEquals(Math.min(k, size), results.scoreDocs.length);
for (int i = 0; i < k - 1; i++) { for (int i = 0; i < k - 1; i++) {
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
@@ -1233,13 +1234,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
squareSum += v[i] * v[i]; squareSum += v[i] * v[i];
} }
} }
return v;
}
protected float[] randomNormalizedVector(int dim) {
float[] v = randomVector(dim);
VectorUtil.l2normalize(v); VectorUtil.l2normalize(v);
return v; return v;
} }
private byte[] randomVector8(int dim) { private byte[] randomVector8(int dim) {
assert dim > 0; assert dim > 0;
float[] v = randomVector(dim); float[] v = randomNormalizedVector(dim);
byte[] b = new byte[dim]; byte[] b = new byte[dim];
for (int i = 0; i < dim; i++) { for (int i = 0; i < dim; i++) {
b[i] = (byte) (v[i] * 127); b[i] = (byte) (v[i] * 127);
@@ -1251,10 +1257,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
try (Directory dir = newDirectory()) { try (Directory dir = newDirectory()) {
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
Document doc = new Document(); Document doc = new Document();
doc.add(new KnnFloatVectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); doc.add(
new KnnFloatVectorField(
"v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc); w.addDocument(doc);
doc.add(new KnnFloatVectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); doc.add(
new KnnFloatVectorField(
"v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
w.addDocument(doc); w.addDocument(doc);
} }
@@ -1359,7 +1369,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction)); doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
} }
case FLOAT32 -> { case FLOAT32 -> {
float[] v = randomVector(dim); float[] v = randomNormalizedVector(dim);
fieldValuesCheckSum += v[0]; fieldValuesCheckSum += v[0];
doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction)); doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction));
} }