This commit is contained in:
Vigya Sharma 2024-11-12 14:40:25 -08:00
parent 9533336c91
commit 5fb88b9c7f
2 changed files with 45 additions and 29 deletions

View File

@ -22,10 +22,10 @@ import org.apache.lucene.util.ArrayUtil;
/**
* Computes similarity between two multi-vectors.
* <p>
* A multi-vector is a collection of multiple vectors that represent a single document or query.
* MultiVectorSimilarityFunction is used to determine nearest neighbors during
* indexing and search on multi-vectors.
*
* <p>A multi-vector is a collection of multiple vectors that represent a single document or query.
* MultiVectorSimilarityFunction is used to determine nearest neighbors during indexing and search
* on multi-vectors.
*/
public class MultiVectorSimilarityFunction {
@ -33,9 +33,8 @@ public class MultiVectorSimilarityFunction {
public enum Aggregation {
/**
* Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity
* found for each vector in the first multi-vector against all vectors in the second
* multi-vector.
* Sum_Max Similarity between two multi-vectors. Computes the sum of maximum similarity found
* for each vector in the first multi-vector against all vectors in the second multi-vector.
*/
SUM_MAX {
@Override
@ -103,8 +102,8 @@ public class MultiVectorSimilarityFunction {
/**
* Computes and aggregates similarity over multiple vector values.
* <p>
* Assumes all vector values in both provided multi-vectors have the same dimension. Slices
*
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices
* inner and outer float[] multi-vectors into dimension sized vector values for comparison.
*
* @param outer first multi-vector
@ -121,8 +120,8 @@ public class MultiVectorSimilarityFunction {
/**
* Computes and aggregates similarity over multiple vector values.
* <p>
* Assumes all vector values in both provided multi-vectors have the same dimension. Slices
*
* <p>Assumes all vector values in both provided multi-vectors have the same dimension. Slices
* inner and outer byte[] multi-vectors into dimension sized vector values for comparison.
*
* @param outer first multi-vector

View File

@ -11,46 +11,60 @@ public class TestMultiVectorSimilarityFunction extends LuceneTestCase {
final int dimension = 3;
final VectorSimilarityFunction vectorSim = VectorSimilarityFunction.DOT_PRODUCT;
float[][] a = new float[][] {
VectorUtil.l2normalize(new float[]{1.f, 4.f, -3.f}),
VectorUtil.l2normalize(new float[]{8.f, 3.f, -7.f})
};
float[][] b = new float[][] {
VectorUtil.l2normalize(new float[]{-5.f, 2.f, 4.f}),
VectorUtil.l2normalize(new float[]{7.f, 1.f, -3.f}),
VectorUtil.l2normalize(new float[]{-5.f, 8.f, 3.f})
};
float[][] a =
new float[][] {
VectorUtil.l2normalize(new float[] {1.f, 4.f, -3.f}),
VectorUtil.l2normalize(new float[] {8.f, 3.f, -7.f})
};
float[][] b =
new float[][] {
VectorUtil.l2normalize(new float[] {-5.f, 2.f, 4.f}),
VectorUtil.l2normalize(new float[] {7.f, 1.f, -3.f}),
VectorUtil.l2normalize(new float[] {-5.f, 8.f, 3.f})
};
float result = 0f;
float[] a0_bDot = new float[] {vectorSim.compare(a[0], b[0]), vectorSim.compare(a[0], b[1]), vectorSim.compare(a[0], b[2])};
float[] a0_bDot =
new float[] {
vectorSim.compare(a[0], b[0]),
vectorSim.compare(a[0], b[1]),
vectorSim.compare(a[0], b[2])
};
float max = Float.MIN_VALUE;
for (float k: a0_bDot) {
for (float k : a0_bDot) {
max = Float.max(max, k);
}
result += max;
float[] a1_bDot = new float[] {vectorSim.compare(a[1], b[0]), vectorSim.compare(a[1], b[1]), vectorSim.compare(a[1], b[2])};
float[] a1_bDot =
new float[] {
vectorSim.compare(a[1], b[0]),
vectorSim.compare(a[1], b[1]),
vectorSim.compare(a[1], b[2])
};
max = Float.MIN_VALUE;
for (float k: a1_bDot) {
for (float k : a1_bDot) {
max = Float.max(max, k);
}
result += max;
float[] a_Packed = new float[a.length * dimension];
int i = 0;
for (float[] v: a) {
for (float[] v : a) {
System.arraycopy(v, 0, a_Packed, i, dimension);
i += dimension;
}
float[] b_Packed = new float[b.length * dimension];
i = 0;
for (float[] v: b) {
for (float[] v : b) {
System.arraycopy(v, 0, b_Packed, i, dimension);
i += dimension;
}
MultiVectorSimilarityFunction mvSim = new MultiVectorSimilarityFunction(VectorSimilarityFunction.DOT_PRODUCT, MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
MultiVectorSimilarityFunction mvSim =
new MultiVectorSimilarityFunction(
VectorSimilarityFunction.DOT_PRODUCT,
MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
float score = mvSim.compare(a_Packed, b_Packed, dimension);
assertEquals(result, score, 0.0001f);
}
@ -59,7 +73,10 @@ public class TestMultiVectorSimilarityFunction extends LuceneTestCase {
public void testDimensionCheck() {
float[] a = {1f, 2f, 3f, 4f, 5f, 6f};
float[] b = {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f};
MultiVectorSimilarityFunction mvSim = new MultiVectorSimilarityFunction(VectorSimilarityFunction.DOT_PRODUCT, MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
MultiVectorSimilarityFunction mvSim =
new MultiVectorSimilarityFunction(
VectorSimilarityFunction.DOT_PRODUCT,
MultiVectorSimilarityFunction.Aggregation.SUM_MAX);
assertThrows(IllegalArgumentException.class, () -> mvSim.compare(a, b, 2));
}
}