diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 613fa89b020..93bc6b0011a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -142,7 +142,14 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.dim]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; for (float[] v : fieldData.floatVectors) { + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); quantizedVectorData.writeBytes(vector, vector.length); @@ -194,8 +201,15 @@ public final class Lucene99ScalarQuantizedVectorsWriter implements Accountable { ScalarQuantizer scalarQuantizer = fieldData.createQuantizer(); byte[] vector = new byte[fieldData.dim]; final ByteBuffer offsetBuffer = ByteBuffer.allocate(Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + float[] copy = fieldData.normalize ? new float[fieldData.dim] : null; for (int ordinal : ordMap) { float[] v = fieldData.floatVectors.get(ordinal); + if (fieldData.normalize) { + System.arraycopy(v, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + v = copy; + } + float offsetCorrection = scalarQuantizer.quantize(v, vector, fieldData.vectorSimilarityFunction); quantizedVectorData.writeBytes(vector, vector.length); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 4e262ca4c9d..d3d13d5ef61 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -38,8 +38,10 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; import org.apache.lucene.util.ScalarQuantizer; +import org.apache.lucene.util.VectorUtil; public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormatTestCase { + @Override protected Codec getCodec() { return new Lucene99Codec() { @@ -57,6 +59,7 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat // create lucene directory with codec int numVectors = 1 + random().nextInt(50); VectorSimilarityFunction similarityFunction = randomSimilarity(); + boolean normalize = similarityFunction == VectorSimilarityFunction.COSINE; int dim = random().nextInt(64) + 1; List vectors = new ArrayList<>(numVectors); for (int i = 0; i < numVectors; i++) { @@ -65,13 +68,22 @@ public class TestLucene99HnswQuantizedVectorsFormat extends BaseKnnVectorsFormat float quantile = Lucene99ScalarQuantizedVectorsFormat.calculateDefaultQuantile(dim); ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors( - new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, false), quantile); + new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors, normalize), + quantile); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { + float[] vector = vectors.get(i); + if (normalize) { + float[] copy = new float[vector.length]; + System.arraycopy(vector, 0, copy, 0, copy.length); + VectorUtil.l2normalize(copy); + vector = copy; + } + expectedVectors[i] = new byte[dim]; expectedCorrections[i] = - scalarQuantizer.quantize(vectors.get(i), expectedVectors[i], similarityFunction); + scalarQuantizer.quantize(vector, expectedVectors[i], similarityFunction); } float[] randomlyReusedVector = new float[dim]; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 955c6ae7127..d6883793cfb 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -81,7 +81,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe protected void addRandomFields(Document doc) { switch (vectorEncoding) { case BYTE -> doc.add(new KnnByteVectorField("v2", randomVector8(30), similarityFunction)); - case FLOAT32 -> doc.add(new KnnFloatVectorField("v2", randomVector(30), similarityFunction)); + case FLOAT32 -> doc.add( + new KnnFloatVectorField("v2", randomNormalizedVector(30), similarityFunction)); } } @@ -611,7 +612,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", randomVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); + "v", randomNormalizedVector(3), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -664,7 +665,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe fieldTotals[field] += b[0]; } case FLOAT32 -> { - float[] v = randomVector(fieldDims[field]); + float[] v = randomNormalizedVector(fieldDims[field]); doc.add(new KnnFloatVectorField(fieldName, v, fieldSimilarityFunctions[field])); fieldTotals[field] += v[0]; } @@ -885,7 +886,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe for (int i = 0; i < numDoc; i++) { if (random().nextInt(7) != 3) { // usually index a vector value for a doc - values[i] = randomVector(dimension); + values[i] = randomNormalizedVector(dimension); ++numValues; } if (random().nextBoolean() && values[i] != null) { @@ -1033,7 +1034,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe float[] value; if (random().nextInt(7) != 3) { // usually index a vector value for a doc - value = randomVector(dimension); + value = randomNormalizedVector(dimension); } else { value = null; } @@ -1061,7 +1062,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation); assertEquals(visitedLimit, results.totalHits.value); @@ -1071,7 +1072,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation); assertTrue(results.totalHits.value <= visitedLimit); } @@ -1097,7 +1098,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe float[] value; if (random().nextInt(7) != 3) { // usually index a vector value for a doc - value = randomVector(dimension); + value = randomNormalizedVector(dimension); } else { value = null; } @@ -1147,7 +1148,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomVector(dimension), k, liveDocs, Integer.MAX_VALUE); + fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE); assertEquals(Math.min(k, size), results.scoreDocs.length); for (int i = 0; i < k - 1; i++) { assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); @@ -1233,13 +1234,18 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe squareSum += v[i] * v[i]; } } + return v; + } + + protected float[] randomNormalizedVector(int dim) { + float[] v = randomVector(dim); VectorUtil.l2normalize(v); return v; } private byte[] randomVector8(int dim) { assert dim > 0; - float[] v = randomVector(dim); + float[] v = randomNormalizedVector(dim); byte[] b = new byte[dim]; for (int i = 0; i < dim; i++) { b[i] = (byte) (v[i] * 127); @@ -1251,10 +1257,14 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe try (Directory dir = newDirectory()) { try (IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { Document doc = new Document(); - doc.add(new KnnFloatVectorField("v1", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); + doc.add( + new KnnFloatVectorField( + "v1", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); - doc.add(new KnnFloatVectorField("v2", randomVector(3), VectorSimilarityFunction.EUCLIDEAN)); + doc.add( + new KnnFloatVectorField( + "v2", randomNormalizedVector(3), VectorSimilarityFunction.EUCLIDEAN)); w.addDocument(doc); } @@ -1359,7 +1369,7 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe doc.add(new KnnByteVectorField("knn_vector", b, similarityFunction)); } case FLOAT32 -> { - float[] v = randomVector(dim); + float[] v = randomNormalizedVector(dim); fieldValuesCheckSum += v[0]; doc.add(new KnnFloatVectorField("knn_vector", v, similarityFunction)); }