add test for mvSim

This commit is contained in:
Vigya Sharma 2024-11-12 13:47:49 -08:00
parent cf97155d93
commit 9533336c91
1 changed files with 51 additions and 5 deletions

View File

@ -1,6 +1,7 @@
package org.apache.lucene.index; package org.apache.lucene.index;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.util.VectorUtil;
import org.junit.Test; import org.junit.Test;
public class TestMultiVectorSimilarityFunction extends LuceneTestCase { public class TestMultiVectorSimilarityFunction extends LuceneTestCase {
@ -8,12 +9,57 @@ public class TestMultiVectorSimilarityFunction extends LuceneTestCase {
@Test @Test
public void testSumMaxWithDotProduct() { public void testSumMaxWithDotProduct() {
final int dimension = 3; final int dimension = 3;
float[] a = new float[] {1.f, 4.f, -3.f, 8.f, 3.f, -7.f, -2.f, 1.f, 9.f}; final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT;
float[] b = new float[] {-5.f, 2.f, 4.f, 7.f, 1.f, -3.f, -5.f, 8.f, 3.f};
MultiVectorSimilarityFunction mvsf = new MultiVectorSimilarityFunction(VectorSimilarityFunction.DOT_PRODUCT, MultiVectorSimilarityFunction.Aggregation.SUM_MAX); float[][] a = new float[][] {
float score = mvsf.compare(a, b, dimension); VectorUtil.l2normalize(new float[]{1.f, 4.f, -3.f}),
assertEquals(95f, score, 0.00001f); VectorUtil.l2normalize(new float[]{8.f, 3.f, -7.f})
};
float[][] b = new float[][] {
VectorUtil.l2normalize(new float[]{-5.f, 2.f, 4.f}),
VectorUtil.l2normalize(new float[]{7.f, 1.f, -3.f}),
VectorUtil.l2normalize(new float[]{-5.f, 8.f, 3.f})
};
float result = 0f;
float[] a0_bDot = new float[] {vectorSim.compare(a[0], b[0]), vectorSim.compare(a[0], b[1]), vectorSim.compare(a[0], b[2])};
float max = Float.MIN_VALUE;
for (float k: a0_bDot) {
max = Float.max(max, k);
}
result += max;
float[] a1_bDot = new float[] {vectorSim.compare(a[1], b[0]), vectorSim.compare(a[1], b[1]), vectorSim.compare(a[1], b[2])};
max = Float.MIN_VALUE;
for (float k: a1_bDot) {
max = Float.max(max, k);
}
result += max;
float[] a_Packed = new float[a.length * dimension];
int i = 0;
for (float[] v: a) {
System.arraycopy(v, 0, a_Packed, i, dimension);
i += dimension;
}
float[] b_Packed = new float[b.length * dimension];
i = 0;
for (float[] v: b) {
System.arraycopy(v, 0, b_Packed, i, dimension);
i += dimension;
}
MultiVectorSimilarityFunction mvSim = new MultiVectorSimilarityFunction(VectorSimilarityFunction.DOT_PRODUCT, MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
float score = mvSim.compare(a_Packed, b_Packed, dimension);
assertEquals(result, score, 0.0001f);
} }
@Test
public void testDimensionCheck() {
float[] a = {1f, 2f, 3f, 4f, 5f, 6f};
float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
MultiVectorSimilarityFunction mvSim = new MultiVectorSimilarityFunction(VectorSimilarityFunction.DOT_PRODUCT, MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2));
}
} }