mirror of https://github.com/apache/lucene.git
LUCENE-10146: Add VectorSimilarityFunction.COSINE (#366)
This PR adds support for using cosine similarity with kNN vector fields. It takes a simple approach and doesn't attempt optimizations like normalizing the query vector in advance, or performing loop unrolling. The thinking is that users who prioritize efficiency can normalize all vectors in advance and use `VectorSimilarityFunction.DOT_PRODUCT`.
This commit is contained in:
parent
ed69f6080f
commit
f4861159c3
|
@ -99,11 +99,12 @@ import org.apache.lucene.store.IndexOutput;
|
|||
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
|
||||
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
|
||||
* <li>VectorDimension: it is non-zero if the field is indexed as vectors.
|
||||
* <li>VectorDistFunction: a byte containing distance function used for similarity calculation.
|
||||
* <li>VectorSimilarityFunction: a byte containing distance function used for similarity
|
||||
* calculation.
|
||||
* <ul>
|
||||
* <li>0: no distance function is defined for this field.
|
||||
* <li>1: EUCLIDEAN_HNSW distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
|
||||
* <li>2: DOT_PRODUCT_HNSW score. ({@link VectorSimilarityFunction#DOT_PRODUCT})
|
||||
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
|
||||
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
|
||||
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
|
||||
* </ul>
|
||||
* </ul>
|
||||
*
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.lucene.document;
|
|||
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
* A field that contains a single floating-point numeric vector (or none) for each document. Vectors
|
||||
|
@ -73,8 +74,8 @@ public class KnnVectorField extends Field {
|
|||
/**
|
||||
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
|
||||
* no value. Vectors of a single field share the same dimension and similarity function. Note that
|
||||
* some strategies (notably dot-product) require values to be unit-length, which can be enforced
|
||||
* using VectorUtil.l2Normalize(float[]).
|
||||
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be
|
||||
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
|
||||
*
|
||||
* @param name field name
|
||||
* @param vector value
|
||||
|
|
|
@ -16,8 +16,7 @@
|
|||
*/
|
||||
package org.apache.lucene.index;
|
||||
|
||||
import static org.apache.lucene.util.VectorUtil.dotProduct;
|
||||
import static org.apache.lucene.util.VectorUtil.squareDistance;
|
||||
import static org.apache.lucene.util.VectorUtil.*;
|
||||
|
||||
/**
|
||||
* Vector similarity function; used in search to return top K most similar vectors to a target
|
||||
|
@ -39,13 +38,31 @@ public enum VectorSimilarityFunction {
|
|||
}
|
||||
},
|
||||
|
||||
/** Dot product */
|
||||
/**
|
||||
* Dot product. NOTE: this similarity is intended as an optimized way to perform cosine
|
||||
* similarity. In order to use it, all vectors must be of unit length, including both document and
|
||||
* query vectors. Using dot product with vectors that are not unit length can result in errors or
|
||||
* poor search results.
|
||||
*/
|
||||
DOT_PRODUCT { // similarity is the raw dot product of the two vectors
|
||||
@Override
|
||||
public float compare(float[] v1, float[] v2) { // returns dotProduct(v1, v2) unchanged
|
||||
return dotProduct(v1, v2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float convertToScore(float similarity) { // affine rescale of the similarity into a score
|
||||
return (1 + similarity) / 2; // a similarity in [-1, 1] (the range for unit-length inputs) maps onto [0, 1]
|
||||
}
|
||||
},
|
||||
|
||||
/** Cosine similarity */
|
||||
COSINE {
|
||||
@Override
|
||||
public float compare(float[] v1, float[] v2) {
|
||||
return cosine(v1, v2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public float convertToScore(float similarity) {
|
||||
return (1 + similarity) / 2;
|
||||
|
|
|
@ -23,8 +23,9 @@ public final class VectorUtil {
|
|||
private VectorUtil() {}
|
||||
|
||||
/**
|
||||
* Returns the vector dot product of the two vectors. IllegalArgumentException is thrown if the
|
||||
* vectors' dimensions differ.
|
||||
* Returns the vector dot product of the two vectors.
|
||||
*
|
||||
* @throws IllegalArgumentException if the vectors' dimensions differ.
|
||||
*/
|
||||
public static float dotProduct(float[] a, float[] b) {
|
||||
if (a.length != b.length) {
|
||||
|
@ -95,8 +96,35 @@ public final class VectorUtil {
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns the sum of squared differences of the two vectors. IllegalArgumentException is thrown
|
||||
* if the vectors' dimensions differ.
|
||||
* Returns the cosine similarity between the two vectors.
|
||||
*
|
||||
* @throws IllegalArgumentException if the vectors' dimensions differ.
|
||||
*/
|
||||
public static float cosine(float[] v1, float[] v2) { // cosine = dot(v1, v2) / sqrt(|v1|^2 * |v2|^2)
|
||||
if (v1.length != v2.length) { // both vectors must have the same number of dimensions
|
||||
throw new IllegalArgumentException(
|
||||
"vector dimensions differ: " + v1.length + "!=" + v2.length);
|
||||
}
|
||||
|
||||
float sum = 0.0f; // running dot product of v1 and v2
|
||||
float norm1 = 0.0f; // running squared L2 norm of v1
|
||||
float norm2 = 0.0f; // running squared L2 norm of v2
|
||||
int dim = v1.length;
|
||||
|
||||
for (int i = 0; i < dim; i++) { // single pass: all three sums are accumulated together
|
||||
float elem1 = v1[i];
|
||||
float elem2 = v2[i];
|
||||
sum += elem1 * elem2;
|
||||
norm1 += elem1 * elem1;
|
||||
norm2 += elem2 * elem2;
|
||||
}
|
||||
return (float) (sum / Math.sqrt(norm1 * norm2)); // division is done in double; an all-zero input gives 0/0 = NaN
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the sum of squared differences of the two vectors.
|
||||
*
|
||||
* @throws IllegalArgumentException if the vectors' dimensions differ.
|
||||
*/
|
||||
public static float squareDistance(float[] v1, float[] v2) {
|
||||
if (v1.length != v2.length) {
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
package org.apache.lucene.search;
|
||||
|
||||
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
|
||||
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
|
||||
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
|
||||
import static org.apache.lucene.util.TestVectorUtil.randomVector;
|
||||
|
@ -234,10 +235,65 @@ public class TestKnnVectorQuery extends LuceneTestCase {
|
|||
|
||||
// test getMaxScore
|
||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
||||
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)) = 0.5, then
|
||||
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
||||
* normalized by (1 + x) /2.
|
||||
*/
|
||||
float maxAtZero = 0.99029f;
|
||||
float maxAtZero =
|
||||
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
||||
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
|
||||
|
||||
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
|
||||
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
|
||||
* normalized by (1 + x) /2
|
||||
*/
|
||||
float expected =
|
||||
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
|
||||
assertEquals(expected, scorer.getMaxScore(2), 0);
|
||||
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
|
||||
|
||||
DocIdSetIterator it = scorer.iterator();
|
||||
assertEquals(3, it.cost());
|
||||
assertEquals(0, it.nextDoc());
|
||||
// doc 0 has (1, 1)
|
||||
assertEquals(maxAtZero, scorer.score(), 0.0001);
|
||||
assertEquals(1, it.advance(1));
|
||||
assertEquals(expected, scorer.score(), 0);
|
||||
assertEquals(2, it.nextDoc());
|
||||
// since topK was 3
|
||||
assertEquals(NO_MORE_DOCS, it.advance(4));
|
||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testScoreCosine() throws IOException {
|
||||
try (Directory d = newDirectory()) {
|
||||
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
|
||||
for (int j = 1; j <= 5; j++) {
|
||||
Document doc = new Document();
|
||||
doc.add(new KnnVectorField("field", new float[] {j, j * j}, COSINE));
|
||||
w.addDocument(doc);
|
||||
}
|
||||
}
|
||||
try (IndexReader reader = DirectoryReader.open(d)) {
|
||||
assertEquals(1, reader.leaves().size());
|
||||
IndexSearcher searcher = new IndexSearcher(reader);
|
||||
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
|
||||
Query rewritten = query.rewrite(reader);
|
||||
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
|
||||
Scorer scorer = weight.scorer(reader.leaves().get(0));
|
||||
|
||||
// prior to advancing, score is undefined
|
||||
assertEquals(-1, scorer.docID());
|
||||
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
|
||||
|
||||
// test getMaxScore
|
||||
assertEquals(0, scorer.getMaxScore(-1), 0);
|
||||
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
|
||||
* normalized by (1 + x) /2.
|
||||
*/
|
||||
float maxAtZero =
|
||||
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
|
||||
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
|
||||
|
||||
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
|
||||
|
|
|
@ -71,6 +71,33 @@ public class TestVectorUtil extends LuceneTestCase {
|
|||
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
|
||||
}
|
||||
|
||||
public void testBasicCosine() { // checks cosine() against a hand-computed value
|
||||
assertEquals(
|
||||
0.11952f, VectorUtil.cosine(new float[] {1, 2, 3}, new float[] {-10, 0, 5}), DELTA); // dot = 5, squared norms 14 and 125: 5 / sqrt(1750) ~= 0.11952
|
||||
}
|
||||
|
||||
public void testSelfCosine() {
|
||||
// the cosine of a (non-zero) vector with itself is always equal to 1
|
||||
float[] v = randomVector();
|
||||
assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
|
||||
}
|
||||
|
||||
public void testOrthogonalCosine() {
|
||||
// the cosine of two perpendicular vectors is 0
|
||||
float[] v = new float[2];
|
||||
v[0] = random().nextInt(100);
|
||||
v[1] = random().nextInt(99) + 1; // in [1, 99]: keeps v (and u) non-zero, otherwise cosine is 0/0 = NaN and the assertion fails
|
||||
float[] u = new float[2];
|
||||
u[0] = v[1]; // u = (v[1], -v[0]) is v rotated by 90 degrees, so u . v == 0
|
||||
u[1] = -v[0];
|
||||
assertEquals(0, VectorUtil.cosine(u, v), DELTA);
|
||||
}
|
||||
|
||||
public void testCosineThrowsForDimensionMismatch() { // cosine() must reject vectors of different lengths
|
||||
float[] v = {1, 0, 0}, u = {0, 1}; // 3 dimensions vs 2 dimensions
|
||||
expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(u, v));
|
||||
}
|
||||
|
||||
public void testNormalize() {
|
||||
float[] v = randomVector();
|
||||
v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes
|
||||
|
|
|
@ -957,7 +957,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
// enumerators
|
||||
assertEquals(0, VectorSimilarityFunction.EUCLIDEAN.ordinal());
|
||||
assertEquals(1, VectorSimilarityFunction.DOT_PRODUCT.ordinal());
|
||||
assertEquals(2, VectorSimilarityFunction.values().length);
|
||||
assertEquals(2, VectorSimilarityFunction.COSINE.ordinal());
|
||||
assertEquals(3, VectorSimilarityFunction.values().length);
|
||||
}
|
||||
|
||||
public void testAdvance() throws Exception {
|
||||
|
|
Loading…
Reference in New Issue