Add test for float vector values in FlatVectorsScorer impls (#13851)

This is a test only change that verifies the behaviour when float vector values are passed to our FlatVectorsScorer implementations. This would have caught the bug causing #13844, subsequently fixed by #13850.
This commit is contained in:
Chris Hegarty 2024-10-02 16:05:11 +01:00 committed by ChrisHegarty
parent 19ae89be1b
commit ab1b0b716e
1 changed files with 60 additions and 0 deletions

View File

@ -16,6 +16,7 @@
*/
package org.apache.lucene.internal.vectorization;
import static java.util.Locale.ROOT;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
@ -24,6 +25,8 @@ import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRO
import com.carrotsearch.randomizedtesting.generators.RandomNumbers;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
@ -39,6 +42,7 @@ import java.util.stream.IntStream;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues;
import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues;
import org.apache.lucene.index.KnnVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
@ -329,12 +333,63 @@ public class TestVectorScorer extends LuceneTestCase {
}
}
// Tests that the FlatVectorsScorer handles float vectors correctly.
public void testWithFloatValues() throws IOException {
try (Directory dir = new MMapDirectory(createTempDir("testWithFloatValues"))) {
final String fileName = "floatvalues";
try (IndexOutput out = dir.createOutput(fileName, IOContext.DEFAULT)) {
var vec = floatToByteArray(1f); // single vector, with one dimension
out.writeBytes(vec, 0, vec.length);
}
try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) {
for (int times = 0; times < TIMES; times++) {
for (var sim : List.of(COSINE, EUCLIDEAN, DOT_PRODUCT, MAXIMUM_INNER_PRODUCT)) {
var vectorValues = floatVectorValues(1, 1, in, sim);
assert vectorValues.getEncoding().byteSize == 4;
var supplier1 = DEFAULT_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
var supplier2 = MEMSEG_SCORER.getRandomVectorScorerSupplier(sim, vectorValues);
// these assertion assumes that the supplier and scorer's toString will have float
// in it, since it's based on float vectors.
assertTrue(supplier1.toString().toLowerCase(ROOT).contains("float"));
assertTrue(supplier2.toString().toLowerCase(ROOT).contains("float"));
assertTrue(supplier1.scorer(0).toString().toLowerCase(ROOT).contains("float"));
assertTrue(supplier2.scorer(0).toString().toLowerCase(ROOT).contains("float"));
float expected = supplier1.scorer(0).score(0);
assertEquals(supplier2.scorer(0).score(0), expected, DELTA);
var scorer1 = DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, new float[] {1f});
var scorer2 = MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, new float[] {1f});
assertTrue(scorer1.toString().toLowerCase(ROOT).contains("float"));
assertTrue(scorer2.toString().toLowerCase(ROOT).contains("float"));
expected = scorer1.score(0);
assertEquals(scorer2.score(0), expected, DELTA);
expectThrows(
Throwable.class,
() -> DEFAULT_SCORER.getRandomVectorScorer(sim, vectorValues, new byte[] {1}));
expectThrows(
Throwable.class,
() -> MEMSEG_SCORER.getRandomVectorScorer(sim, vectorValues, new byte[] {1}));
}
}
}
}
}
KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
throws IOException {
return new OffHeapByteVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim);
}
KnnVectorValues floatVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim)
throws IOException {
return new OffHeapFloatVectorValues.DenseOffHeapVectorValues(
dims, size, in.slice("floatValues", 0, in.length()), dims, MEMSEG_SCORER, sim);
}
// creates the vector based on the given ordinal, which is reproducible given the ord and dims
static byte[] vector(int ord, int dims) {
var random = new Random(Objects.hash(ord, dims));
@ -355,6 +410,11 @@ public class TestVectorScorer extends LuceneTestCase {
}
}
/** Converts a float value to a byte array. */
public static byte[] floatToByteArray(float value) {
return ByteBuffer.allocate(4).order(ByteOrder.LITTLE_ENDIAN).putFloat(value).array();
}
static int randomIntBetween(int minInclusive, int maxInclusive) {
return RandomNumbers.randomIntBetween(random(), minInclusive, maxInclusive);
}