mirror of https://github.com/apache/lucene.git
Normalize written scalar quantized vectors when using cosine similarity (#12780)
### Description When using cosine similarity, the `ScalarQuantizer` normalizes vectors when calculating quantiles and `ScalarQuantizedRandomVectorScorer` normalizes query vectors before scoring them, but `Lucene99ScalarQuantizedVectorsWriter` does not normalize the vectors prior to quantizing them when producing the quantized vectors to write to disk. This PR normalizes vectors prior to quantizing them when writing them to disk. Recall results on my M1 with the `glove-100-angular` data set (all using `maxConn`: 16, `beamWidth`: 100, `numCandidates`: 100, `k`: 10, single segment): | Configuration | Recall | Average Query Duration | |---------------|-------|-----------------| | Pre-patch no quantization | 0.78762 | 0.68 ms | | Pre-patch with quantization | 8.999999999999717E-5 | 0.45 ms | | Post-patch no quantization | 0.78762 | 0.70 ms | | Post-patch with quantization | 0.66742 | 0.66 ms |
This commit is contained in:
parent
20d5de448a
commit
ddb01cacd4
|
@ -142,7 +142,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||||
byte[] vector = new byte[fieldData.dim];
|
byte[] vector = new byte[fieldData.dim];
|
||||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
|
||||||
for (float[] v : fieldData.floatVectors) {
|
for (float[] v : fieldData.floatVectors) {
|
||||||
|
if (fieldData.normalize) {
|
||||||
|
System.arraycopy(v, 0, copy, 0, copy.length);
|
||||||
|
VectorUtil.l2normalize(copy);
|
||||||
|
v = copy;
|
||||||
|
}
|
||||||
|
|
||||||
float offsetCorrection =
|
float offsetCorrection =
|
||||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
||||||
quantizedVectorData.writeBytes(vector, vector.length);
|
quantizedVectorData.writeBytes(vector, vector.length);
|
||||||
|
@ -194,8 +201,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable {
|
||||||
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
ScalarQuantizer scalarQuantizer = fieldData.createQuantizer();
|
||||||
byte[] vector = new byte[fieldData.dim];
|
byte[] vector = new byte[fieldData.dim];
|
||||||
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
|
||||||
|
float[] copy = fieldData.normalize ? new float[fieldData.dim] : null;
|
||||||
for (int ordinal : ordMap) {
|
for (int ordinal : ordMap) {
|
||||||
float[] v = fieldData.floatVectors.get(ordinal);
|
float[] v = fieldData.floatVectors.get(ordinal);
|
||||||
|
if (fieldData.normalize) {
|
||||||
|
System.arraycopy(v, 0, copy, 0, copy.length);
|
||||||
|
VectorUtil.l2normalize(copy);
|
||||||
|
v = copy;
|
||||||
|
}
|
||||||
|
|
||||||
float offsetCorrection =
|
float offsetCorrection =
|
||||||
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction);
|
||||||
quantizedVectorData.writeBytes(vector, vector.length);
|
quantizedVectorData.writeBytes(vector, vector.length);
|
||||||
|
|
|
@ -38,8 +38,10 @@ import org.apache.lucene.index.VectorSimilarityFunction;
|
||||||
import org.apache.lucene.store.Directory;
|
import org.apache.lucene.store.Directory;
|
||||||
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
|
||||||
import org.apache.lucene.util.ScalarQuantizer;
|
import org.apache.lucene.util.ScalarQuantizer;
|
||||||
|
import org.apache.lucene.util.VectorUtil;
|
||||||
|
|
||||||
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
protected Codec getCodec() {
|
protected Codec getCodec() {
|
||||||
return new Lucene99Codec() {
|
return new Lucene99Codec() {
|
||||||
|
@ -57,6 +59,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
||||||
// create lucene directory with codec
|
// create lucene directory with codec
|
||||||
int numVectors = 1 + random().nextInt(50);
|
int numVectors = 1 + random().nextInt(50);
|
||||||
VectorSimilarityFunction similarityFunction = randomSimilarity();
|
VectorSimilarityFunction similarityFunction = randomSimilarity();
|
||||||
|
boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE;
|
||||||
int dim = random().nextInt(64) + 1;
|
int dim = random().nextInt(64) + 1;
|
||||||
List<float[]> vectors = new ArrayList<>(numVectors);
|
List<float[]> vectors = new ArrayList<>(numVectors);
|
||||||
for (int i = 0; i < numVectors; i++) {
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
@ -65,13 +68,22 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat
|
||||||
float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
|
float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim);
|
||||||
ScalarQuantizer scalarQuantizer =
|
ScalarQuantizer scalarQuantizer =
|
||||||
ScalarQuantizer.fromVectors(
|
ScalarQuantizer.fromVectors(
|
||||||
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile);
|
new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize),
|
||||||
|
quantile);
|
||||||
float[] expectedCorrections = new float[numVectors];
|
float[] expectedCorrections = new float[numVectors];
|
||||||
byte[][] expectedVectors = new byte[numVectors][];
|
byte[][] expectedVectors = new byte[numVectors][];
|
||||||
for (int i = 0; i < numVectors; i++) {
|
for (int i = 0; i < numVectors; i++) {
|
||||||
|
float[] vector = vectors.get(i);
|
||||||
|
if (normalize) {
|
||||||
|
float[] copy = new float[vector.length];
|
||||||
|
System.arraycopy(vector, 0, copy, 0, copy.length);
|
||||||
|
VectorUtil.l2normalize(copy);
|
||||||
|
vector = copy;
|
||||||
|
}
|
||||||
|
|
||||||
expectedVectors[i] = new byte[dim];
|
expectedVectors[i] = new byte[dim];
|
||||||
expectedCorrections[i] =
|
expectedCorrections[i] =
|
||||||
scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction);
|
scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction);
|
||||||
}
|
}
|
||||||
float[] randomlyReusedVector = new float[dim];
|
float[] randomlyReusedVector = new float[dim];
|
||||||
|
|
||||||
|
|
|
@ -81,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
protected void addRandomFields(Document doc) {
|
protected void addRandomFields(Document doc) {
|
||||||
switch (vectorEncoding) {
|
switch (vectorEncoding) {
|
||||||
case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
|
case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction));
|
||||||
case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomVector(30), similarityFunction));
|
case FLOAT32 -> doc.add(
|
||||||
|
new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -611,7 +612,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
// assert that knn search doesn't fail on a field with all deleted docs
|
// assert that knn search doesn't fail on a field with all deleted docs
|
||||||
TopDocs results =
|
TopDocs results =
|
||||||
leafReader.searchNearestVectors(
|
leafReader.searchNearestVectors(
|
||||||
"v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
|
"v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE);
|
||||||
assertEquals(0, results.scoreDocs.length);
|
assertEquals(0, results.scoreDocs.length);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -664,7 +665,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
fieldTotals[field] += b[0];
|
fieldTotals[field] += b[0];
|
||||||
}
|
}
|
||||||
case FLOAT32 -> {
|
case FLOAT32 -> {
|
||||||
float[] v = randomVector(fieldDims[field]);
|
float[] v = randomNormalizedVector(fieldDims[field]);
|
||||||
doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field]));
|
doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field]));
|
||||||
fieldTotals[field] += v[0];
|
fieldTotals[field] += v[0];
|
||||||
}
|
}
|
||||||
|
@ -885,7 +886,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
for (int i = 0; i < numDoc; i++) {
|
for (int i = 0; i < numDoc; i++) {
|
||||||
if (random().nextInt(7) != 3) {
|
if (random().nextInt(7) != 3) {
|
||||||
// usually index a vector value for a doc
|
// usually index a vector value for a doc
|
||||||
values[i] = randomVector(dimension);
|
values[i] = randomNormalizedVector(dimension);
|
||||||
++numValues;
|
++numValues;
|
||||||
}
|
}
|
||||||
if (random().nextBoolean() && values[i] != null) {
|
if (random().nextBoolean() && values[i] != null) {
|
||||||
|
@ -1033,7 +1034,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
float[] value;
|
float[] value;
|
||||||
if (random().nextInt(7) != 3) {
|
if (random().nextInt(7) != 3) {
|
||||||
// usually index a vector value for a doc
|
// usually index a vector value for a doc
|
||||||
value = randomVector(dimension);
|
value = randomNormalizedVector(dimension);
|
||||||
} else {
|
} else {
|
||||||
value = null;
|
value = null;
|
||||||
}
|
}
|
||||||
|
@ -1061,7 +1062,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
TopDocs results =
|
TopDocs results =
|
||||||
ctx.reader()
|
ctx.reader()
|
||||||
.searchNearestVectors(
|
.searchNearestVectors(
|
||||||
fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
|
fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
|
||||||
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
|
assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation);
|
||||||
assertEquals(visitedLimit, results.totalHits.value);
|
assertEquals(visitedLimit, results.totalHits.value);
|
||||||
|
|
||||||
|
@ -1071,7 +1072,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
results =
|
results =
|
||||||
ctx.reader()
|
ctx.reader()
|
||||||
.searchNearestVectors(
|
.searchNearestVectors(
|
||||||
fieldName, randomVector(dimension), k, liveDocs, visitedLimit);
|
fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit);
|
||||||
assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
|
assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation);
|
||||||
assertTrue(results.totalHits.value <= visitedLimit);
|
assertTrue(results.totalHits.value <= visitedLimit);
|
||||||
}
|
}
|
||||||
|
@ -1097,7 +1098,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
float[] value;
|
float[] value;
|
||||||
if (random().nextInt(7) != 3) {
|
if (random().nextInt(7) != 3) {
|
||||||
// usually index a vector value for a doc
|
// usually index a vector value for a doc
|
||||||
value = randomVector(dimension);
|
value = randomNormalizedVector(dimension);
|
||||||
} else {
|
} else {
|
||||||
value = null;
|
value = null;
|
||||||
}
|
}
|
||||||
|
@ -1147,7 +1148,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
TopDocs results =
|
TopDocs results =
|
||||||
ctx.reader()
|
ctx.reader()
|
||||||
.searchNearestVectors(
|
.searchNearestVectors(
|
||||||
fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE);
|
fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE);
|
||||||
assertEquals(Math.min(k, size), results.scoreDocs.length);
|
assertEquals(Math.min(k, size), results.scoreDocs.length);
|
||||||
for (int i = 0; i < k - 1; i++) {
|
for (int i = 0; i < k - 1; i++) {
|
||||||
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
|
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
|
||||||
|
@ -1233,13 +1234,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
squareSum += v[i] * v[i];
|
squareSum += v[i] * v[i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return v;
|
||||||
|
}
|
||||||
|
|
||||||
|
protected float[] randomNormalizedVector(int dim) {
|
||||||
|
float[] v = randomVector(dim);
|
||||||
VectorUtil.l2normalize(v);
|
VectorUtil.l2normalize(v);
|
||||||
return v;
|
return v;
|
||||||
}
|
}
|
||||||
|
|
||||||
private byte[] randomVector8(int dim) {
|
private byte[] randomVector8(int dim) {
|
||||||
assert dim > 0;
|
assert dim > 0;
|
||||||
float[] v = randomVector(dim);
|
float[] v = randomNormalizedVector(dim);
|
||||||
byte[] b = new byte[dim];
|
byte[] b = new byte[dim];
|
||||||
for (int i = 0; i < dim; i++) {
|
for (int i = 0; i < dim; i++) {
|
||||||
b[i] = (byte) (v[i] * 127);
|
b[i] = (byte) (v[i] * 127);
|
||||||
|
@ -1251,10 +1257,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
try (Directory dir = newDirectory()) {
|
try (Directory dir = newDirectory()) {
|
||||||
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||||
Document doc = new Document();
|
Document doc = new Document();
|
||||||
doc.add(new KnnFloatVectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
|
doc.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
|
||||||
w.addDocument(doc);
|
w.addDocument(doc);
|
||||||
|
|
||||||
doc.add(new KnnFloatVectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN));
|
doc.add(
|
||||||
|
new KnnFloatVectorField(
|
||||||
|
"v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN));
|
||||||
w.addDocument(doc);
|
w.addDocument(doc);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1359,7 +1369,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
||||||
doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
|
doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction));
|
||||||
}
|
}
|
||||||
case FLOAT32 -> {
|
case FLOAT32 -> {
|
||||||
float[] v = randomVector(dim);
|
float[] v = randomNormalizedVector(dim);
|
||||||
fieldValuesCheckSum += v[0];
|
fieldValuesCheckSum += v[0];
|
||||||
doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction));
|
doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction));
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue