Ensure negative scores are not returned from scalar quantization scorer (#13356)

Depending on how we quantize and then scale, we can edge down below 0 for dot product scores.

This is exceptionally rare; I have only seen it in extreme circumstances in tests (with random data and low dimensionality).
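For intuition, here is a minimal standalone sketch (not from the commit; the constants are invented to reproduce the edge case). The quantized dot product itself is always >= 0, but once it is scaled by the quantization constant and the two score-correction offsets are added, the corrected value can land just below -1, and the unclamped (1 + f) / 2 adjustment then turns slightly negative:

    // Hypothetical values, chosen only to push the corrected score just below -1.
    int dotProduct = 0;              // quantized dot product; always >= 0
    float constMultiplier = 0.004f;  // assumed quantization scaling constant
    float queryOffset = -0.52f;      // assumed query score correction
    float vectorOffset = -0.49f;     // assumed stored-vector score correction
    float adjusted = dotProduct * constMultiplier + queryOffset + vectorOffset; // -1.01
    float unclamped = (1 + adjusted) / 2;            // -0.005f: an invalid negative score
    float clamped = Math.max((1 + adjusted) / 2, 0); // 0.0f: the behavior after this fix

The diff below applies exactly this Math.max clamp on the COSINE and DOT_PRODUCT paths; the MAXIMUM_INNER_PRODUCT path keeps its existing VectorUtil::scaleMaxInnerProductScore adjustment, which never produces a negative value.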
Benjamin Trent 2024-05-13 11:00:04 -04:00 committed by GitHub
parent 8c738ba010
commit f10748cee7
6 changed files with 184 additions and 4 deletions

lucene/CHANGES.txt

@@ -361,6 +361,8 @@ Bug Fixes
 * GITHUB#12966: Aggregation facets no longer assume that aggregation values are positive. (Stefan Vodita)
+* GITHUB#13356: Ensure negative scores are not returned from scalar quantization scorer. (Ben Trent)
+
 Build
 ---------------------

lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java

@@ -100,11 +100,10 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
     return switch (sim) {
       case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes);
       case COSINE, DOT_PRODUCT -> dotProductFactory(
-          targetBytes, offsetCorrection, sim, constMultiplier, values, f -> (1 + f) / 2);
+          targetBytes, offsetCorrection, constMultiplier, values, f -> Math.max((1 + f) / 2, 0));
       case MAXIMUM_INNER_PRODUCT -> dotProductFactory(
           targetBytes,
           offsetCorrection,
-          sim,
           constMultiplier,
           values,
           VectorUtil::scaleMaxInnerProductScore);
@@ -114,7 +113,6 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
   private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory(
       byte[] targetBytes,
       float offsetCorrection,
-      VectorSimilarityFunction sim,
       float constMultiplier,
       RandomAccessQuantizedByteVectorValues values,
       FloatToFloatFunction scoreAdjustmentFunction) {
@@ -179,6 +177,8 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
       byte[] storedVector = values.vectorValue(vectorOrdinal);
       float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
       int dotProduct = VectorUtil.dotProduct(storedVector, targetBytes);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert dotProduct >= 0;
       float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
       return scoreAdjustmentFunction.apply(adjustedDistance);
     }
@@ -216,6 +216,8 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
       values.getSlice().readBytes(compressedVector, 0, compressedVector.length);
       float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
       int dotProduct = VectorUtil.int4DotProductPacked(targetBytes, compressedVector);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert dotProduct >= 0;
       float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
       return scoreAdjustmentFunction.apply(adjustedDistance);
     }
@@ -247,6 +249,8 @@ public class Lucene99ScalarQuantizedVectorScorer implements FlatVectorsScorer {
       byte[] storedVector = values.vectorValue(vectorOrdinal);
       float vectorOffset = values.getScoreCorrectionConstant(vectorOrdinal);
       int dotProduct = VectorUtil.int4DotProduct(storedVector, targetBytes);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert dotProduct >= 0;
       float adjustedDistance = dotProduct * constMultiplier + offsetCorrection + vectorOffset;
       return scoreAdjustmentFunction.apply(adjustedDistance);
     }

lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizedVectorSimilarity.java

@@ -80,8 +80,10 @@ public interface ScalarQuantizedVectorSimilarity {
     public float score(
         byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
       int dotProduct = comparator.compare(storedVector, queryVector);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert dotProduct >= 0;
       float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
-      return (1 + adjustedDistance) / 2;
+      return Math.max((1 + adjustedDistance) / 2, 0);
     }
   }
@@ -99,6 +101,8 @@ public interface ScalarQuantizedVectorSimilarity {
     public float score(
         byte[] queryVector, float queryOffset, byte[] storedVector, float vectorOffset) {
       int dotProduct = comparator.compare(storedVector, queryVector);
+      // For the current implementation of scalar quantization, all dotproducts should be >= 0;
+      assert dotProduct >= 0;
       float adjustedDistance = dotProduct * constMultiplier + queryOffset + vectorOffset;
       return scaleMaxInnerProductScore(adjustedDistance);
     }

lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java

@@ -36,6 +36,9 @@ import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.NoMergePolicy;
 import org.apache.lucene.index.VectorSimilarityFunction;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TopKnnCollector;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase;
 import org.apache.lucene.util.SameThreadExecutorService;
@@ -77,6 +80,41 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase {
     };
   }
+
+  public void testQuantizationScoringEdgeCase() throws Exception {
+    float[][] vectors = new float[][] {{0.6f, 0.8f}, {0.8f, 0.6f}, {-0.6f, -0.8f}};
+    try (Directory dir = newDirectory();
+        IndexWriter w =
+            new IndexWriter(
+                dir,
+                newIndexWriterConfig()
+                    .setCodec(
+                        new Lucene99Codec() {
+                          @Override
+                          public KnnVectorsFormat getKnnVectorsFormatForField(String field) {
+                            return new Lucene99HnswScalarQuantizedVectorsFormat(
+                                16, 100, 1, (byte) 7, false, 0.9f, null);
+                          }
+                        }))) {
+      for (float[] vector : vectors) {
+        Document doc = new Document();
+        doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.DOT_PRODUCT));
+        w.addDocument(doc);
+        w.commit();
+      }
+      w.forceMerge(1);
+      try (IndexReader reader = DirectoryReader.open(w)) {
+        LeafReader r = getOnlyLeafReader(reader);
+        TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE);
+        r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null);
+        TopDocs topDocs = topKnnCollector.topDocs();
+        assertEquals(3, topDocs.totalHits.value);
+        for (ScoreDoc scoreDoc : topDocs.scoreDocs) {
+          assertTrue(scoreDoc.score >= 0f);
+        }
+      }
+    }
+  }
+
   public void testQuantizedVectorsWriteAndRead() throws Exception {
     // create lucene directory with codec
     int numVectors = 1 + random().nextInt(50);

lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java

@@ -17,7 +17,12 @@
 package org.apache.lucene.codecs.lucene99;
+import static org.apache.lucene.codecs.lucene99.OffHeapQuantizedByteVectorValues.compressBytes;
+import java.io.ByteArrayOutputStream;
 import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
 import org.apache.lucene.codecs.Codec;
 import org.apache.lucene.codecs.KnnVectorsFormat;
 import org.apache.lucene.codecs.KnnVectorsReader;
@@ -32,9 +37,14 @@ import org.apache.lucene.index.IndexWriterConfig;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.VectorSimilarityFunction;
 import org.apache.lucene.store.Directory;
+import org.apache.lucene.store.IOContext;
+import org.apache.lucene.store.IndexInput;
+import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.tests.util.LuceneTestCase;
 import org.apache.lucene.util.VectorUtil;
 import org.apache.lucene.util.hnsw.RandomVectorScorer;
+import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues;
+import org.apache.lucene.util.quantization.ScalarQuantizer;
public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
@@ -54,6 +64,95 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
     };
   }
+
+  public void testNonZeroScores() throws IOException {
+    for (int bits : new int[] {4, 7}) {
+      for (boolean compress : new boolean[] {true, false}) {
+        vectorNonZeroScoringTest(bits, compress);
+      }
+    }
+  }
+
+  private void vectorNonZeroScoringTest(int bits, boolean compress) throws IOException {
+    try (Directory dir = newDirectory()) {
+      // keep vecs `0` so dot product is `0`
+      byte[] vec1 = new byte[32];
+      byte[] vec2 = new byte[32];
+      if (compress && bits == 4) {
+        byte[] vec1Compressed = new byte[16];
+        byte[] vec2Compressed = new byte[16];
+        compressBytes(vec1, vec1Compressed);
+        compressBytes(vec2, vec2Compressed);
+        vec1 = vec1Compressed;
+        vec2 = vec2Compressed;
+      }
+      String fileName = getTestName() + "-32";
+      try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
+        // large negative offset to override any query score correction and
+        // ensure negative values that need to be snapped to `0`
+        var negativeOffset = floatToByteArray(-50f);
+        byte[] bytes = concat(vec1, negativeOffset, vec2, negativeOffset);
+        out.writeBytes(bytes, 0, bytes.length);
+      }
+      ScalarQuantizer scalarQuantizer = new ScalarQuantizer(0.1f, 0.9f, (byte) bits);
+      try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
+        Lucene99ScalarQuantizedVectorScorer scorer =
+            new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer());
+        RandomAccessQuantizedByteVectorValues values =
+            new RandomAccessQuantizedByteVectorValues() {
+              @Override
+              public int dimension() {
+                return 32;
+              }
+
+              @Override
+              public int getVectorByteLength() {
+                return compress && bits == 4 ? 16 : 32;
+              }
+
+              @Override
+              public int size() {
+                return 2;
+              }
+
+              @Override
+              public byte[] vectorValue(int ord) {
+                return new byte[32];
+              }
+
+              @Override
+              public float getScoreCorrectionConstant(int ord) {
+                return -50;
+              }
+
+              @Override
+              public RandomAccessQuantizedByteVectorValues copy() throws IOException {
+                return this;
+              }
+
+              @Override
+              public IndexInput getSlice() {
+                return in;
+              }
+
+              @Override
+              public ScalarQuantizer getScalarQuantizer() {
+                return scalarQuantizer;
+              }
+            };
+        float[] queryVector = new float[32];
+        for (int i = 0; i < 32; i++) {
+          queryVector[i] = i * 0.1f;
+        }
+        for (VectorSimilarityFunction function : VectorSimilarityFunction.values()) {
+          RandomVectorScorer randomScorer =
+              scorer.getRandomVectorScorer(function, values, queryVector);
+          assertTrue(randomScorer.score(0) >= 0f);
+          assertTrue(randomScorer.score(1) >= 0f);
+        }
+      }
+    }
+  }
+
   public void testScoringCompressedInt4() throws Exception {
     vectorScoringTest(4, true);
   }
@@ -152,4 +251,17 @@ public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase {
       writer.forceMerge(1);
     }
   }
+
+  private static byte[] floatToByteArray(float value) {
+    return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array();
+  }
+
+  private static byte[] concat(byte[]... arrays) throws IOException {
+    try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
+      for (var ba : arrays) {
+        baos.write(ba);
+      }
+      return baos.toByteArray();
+    }
+  }
 }

lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java

@@ -30,6 +30,26 @@ import org.apache.lucene.util.VectorUtil;
 public class TestScalarQuantizedVectorSimilarity extends LuceneTestCase {
+
+  public void testNonZeroScores() {
+    byte[][] quantized = new byte[2][32];
+    for (VectorSimilarityFunction similarityFunction : VectorSimilarityFunction.values()) {
+      float multiplier = random().nextFloat();
+      if (random().nextBoolean()) {
+        multiplier = -multiplier;
+      }
+      for (byte bits : new byte[] {4, 7}) {
+        ScalarQuantizedVectorSimilarity quantizedSimilarity =
+            ScalarQuantizedVectorSimilarity.fromVectorSimilarity(
+                similarityFunction, multiplier, bits);
+        float negativeOffsetA = -(random().nextFloat() * (random().nextInt(10) + 1));
+        float negativeOffsetB = -(random().nextFloat() * (random().nextInt(10) + 1));
+        float score =
+            quantizedSimilarity.score(quantized[0], negativeOffsetA, quantized[1], negativeOffsetB);
+        assertTrue(score >= 0);
+      }
+    }
+  }
+
   public void testToEuclidean() throws IOException {
     int dims = 128;
     int numVecs = 100;