LUCENE-10146: Add VectorSimilarityFunction.COSINE (#366)

This PR adds support for using cosine similarity with kNN vector fields.

It takes a simple approach and doesn't attempt optimizations like normalizing
the query vector in advance, or performing loop unrolling. The thinking is that
users who prioritize efficiency can normalize all vectors in advance and use
`VectorSimilarityFunction.DOT_PRODUCT`.
This commit is contained in:
Julie Tibshirani 2021-10-11 08:49:19 -07:00 committed by GitHub
parent ed69f6080f
commit f4861159c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 147 additions and 16 deletions

View File

@ -99,11 +99,12 @@ import org.apache.lucene.store.IndexOutput;
* <li>PointDimensionCount, PointNumBytes: these are non-zero only if the field is indexed as
* points, e.g. using {@link org.apache.lucene.document.LongPoint}
* <li>VectorDimension: it is non-zero if the field is indexed as vectors.
* <li>VectorDistFunction: a byte containing distance function used for similarity calculation.
* <li>VectorSimilarityFunction: a byte identifying the similarity function used for vector
* similarity calculation.
* <ul>
* <li>0: no distance function is defined for this field.
* <li>1: EUCLIDEAN_HNSW distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>2: DOT_PRODUCT_HNSW score. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>0: EUCLIDEAN distance. ({@link VectorSimilarityFunction#EUCLIDEAN})
* <li>1: DOT_PRODUCT similarity. ({@link VectorSimilarityFunction#DOT_PRODUCT})
* <li>2: COSINE similarity. ({@link VectorSimilarityFunction#COSINE})
* </ul>
* </ul>
*

View File

@ -19,6 +19,7 @@ package org.apache.lucene.document;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.util.VectorUtil;
/**
* A field that contains a single floating-point numeric vector (or none) for each document. Vectors
@ -73,8 +74,8 @@ public class KnnVectorField extends Field {
/**
* Creates a numeric vector field. Fields are single-valued: each document has either one value or
* no value. Vectors of a single field share the same dimension and similarity function. Note that
* some strategies (notably dot-product) require values to be unit-length, which can be enforced
* using VectorUtil.l2Normalize(float[]).
* some strategies (like {@link VectorSimilarityFunction#DOT_PRODUCT}) require values to be
* unit-length, which can be enforced using {@link VectorUtil#l2normalize(float[])}.
*
* @param name field name
* @param vector value

View File

@ -16,8 +16,7 @@
*/
package org.apache.lucene.index;
import static org.apache.lucene.util.VectorUtil.dotProduct;
import static org.apache.lucene.util.VectorUtil.squareDistance;
import static org.apache.lucene.util.VectorUtil.*;
/**
* Vector similarity function; used in search to return top K most similar vectors to a target
@ -39,13 +38,31 @@ public enum VectorSimilarityFunction {
}
},
/** Dot product */
/**
* Dot product. NOTE: this similarity is intended as an optimized way to perform cosine
* similarity. In order to use it, all vectors must be of unit length, including both document and
* query vectors. Using dot product with vectors that are not unit length can result in errors or
* poor search results.
*/
DOT_PRODUCT {
// Raw similarity is the plain dot product of the two vectors; no normalization is applied here.
@Override
public float compare(float[] v1, float[] v2) {
return dotProduct(v1, v2);
}
// Shifts and halves the similarity: a value in [-1, 1] (the range of a dot product of two
// unit-length vectors) is mapped linearly onto [0, 1].
@Override
public float convertToScore(float similarity) {
return (1 + similarity) / 2;
}
},
/** Cosine similarity */
COSINE {
@Override
public float compare(float[] v1, float[] v2) {
return cosine(v1, v2);
}
@Override
public float convertToScore(float similarity) {
return (1 + similarity) / 2;

View File

@ -23,8 +23,9 @@ public final class VectorUtil {
private VectorUtil() {}
/**
* Returns the vector dot product of the two vectors. IllegalArgumentException is thrown if the
* vectors' dimensions differ.
* Returns the vector dot product of the two vectors.
*
* @throws IllegalArgumentException if the vectors' dimensions differ.
*/
public static float dotProduct(float[] a, float[] b) {
if (a.length != b.length) {
@ -95,8 +96,35 @@ public final class VectorUtil {
}
/**
* Returns the sum of squared differences of the two vectors. IllegalArgumentException is thrown
* if the vectors' dimensions differ.
* Returns the cosine similarity between the two vectors.
*
* @throws IllegalArgumentException if the vectors' dimensions differ.
*/
public static float cosine(float[] v1, float[] v2) {
  if (v1.length != v2.length) {
    throw new IllegalArgumentException(
        "vector dimensions differ: " + v1.length + "!=" + v2.length);
  }

  // Single pass over both vectors: accumulate the dot product and the two squared magnitudes.
  float sum = 0.0f;
  float norm1 = 0.0f;
  float norm2 = 0.0f;
  int dim = v1.length;
  for (int i = 0; i < dim; i++) {
    float elem1 = v1[i];
    float elem2 = v2[i];
    sum += elem1 * elem2;
    norm1 += elem1 * elem1;
    norm2 += elem2 * elem2;
  }

  // Widen to double BEFORE multiplying the squared magnitudes. A float product can overflow to
  // Infinity (or underflow to 0) even when each individual norm is representable, which would
  // turn the result into 0 or NaN/Infinity for otherwise valid vectors.
  // NOTE: if either vector has zero magnitude the denominator is 0 and the result is NaN.
  return (float) (sum / Math.sqrt((double) norm1 * norm2));
}
/**
* Returns the sum of squared differences of the two vectors.
*
* @throws IllegalArgumentException if the vectors' dimensions differ.
*/
public static float squareDistance(float[] v1, float[] v2) {
if (v1.length != v2.length) {

View File

@ -17,6 +17,7 @@
package org.apache.lucene.search;
import static com.carrotsearch.randomizedtesting.RandomizedTest.frequently;
import static org.apache.lucene.index.VectorSimilarityFunction.COSINE;
import static org.apache.lucene.index.VectorSimilarityFunction.DOT_PRODUCT;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;
@ -234,10 +235,65 @@ public class TestKnnVectorQuery extends LuceneTestCase {
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)) = 0.5, then
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) /2.
*/
float maxAtZero = 0.99029f;
float maxAtZero =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)
* is the closest to (2, 3)). This is ((2,3) * (2, 4) = 16) / (||2, 3|| * ||2, 4|| = sqrt(260)), then
* normalized by (1 + x) /2
*/
float expected =
(float) ((1 + (2 * 2 + 3 * 4) / Math.sqrt((2 * 2 + 3 * 3) * (2 * 2 + 4 * 4))) / 2);
assertEquals(expected, scorer.getMaxScore(2), 0);
assertEquals(expected, scorer.getMaxScore(Integer.MAX_VALUE), 0);
DocIdSetIterator it = scorer.iterator();
assertEquals(3, it.cost());
assertEquals(0, it.nextDoc());
// doc 0 has (1, 1)
assertEquals(maxAtZero, scorer.score(), 0.0001);
assertEquals(1, it.advance(1));
assertEquals(expected, scorer.score(), 0);
assertEquals(2, it.nextDoc());
// since topK was 3
assertEquals(NO_MORE_DOCS, it.advance(4));
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
}
}
}
public void testScoreCosine() throws IOException {
try (Directory d = newDirectory()) {
try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
for (int j = 1; j <= 5; j++) {
Document doc = new Document();
doc.add(new KnnVectorField("field", new float[] {j, j * j}, COSINE));
w.addDocument(doc);
}
}
try (IndexReader reader = DirectoryReader.open(d)) {
assertEquals(1, reader.leaves().size());
IndexSearcher searcher = new IndexSearcher(reader);
KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
Query rewritten = query.rewrite(reader);
Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
Scorer scorer = weight.scorer(reader.leaves().get(0));
// prior to advancing, score is undefined
assertEquals(-1, scorer.docID());
expectThrows(ArrayIndexOutOfBoundsException.class, scorer::score);
// test getMaxScore
assertEquals(0, scorer.getMaxScore(-1), 0);
/* maxAtZero = ((2,3) * (1, 1) = 5) / (||2, 3|| * ||1, 1|| = sqrt(26)), then
* normalized by (1 + x) /2.
*/
float maxAtZero =
(float) ((1 + (2 * 1 + 3 * 1) / Math.sqrt((2 * 2 + 3 * 3) * (1 * 1 + 1 * 1))) / 2);
assertEquals(maxAtZero, scorer.getMaxScore(0), 0.001);
/* max at 2 is actually the score for doc 1 which is the highest (since doc 1 vector (2, 4)

View File

@ -71,6 +71,33 @@ public class TestVectorUtil extends LuceneTestCase {
assertEquals(4 * l2(v), VectorUtil.squareDistance(u, v), DELTA);
}
public void testBasicCosine() {
  // dot = 5, |a|^2 = 14, |b|^2 = 125  ->  5 / sqrt(14 * 125) ~= 0.11952
  float[] a = {1, 2, 3};
  float[] b = {-10, 0, 5};
  float actual = VectorUtil.cosine(a, b);
  assertEquals(0.11952f, actual, DELTA);
}
public void testSelfCosine() {
  // the cosine of a non-zero vector with itself is always equal to 1
  float[] v = randomVector();
  // A zero-magnitude vector makes cosine evaluate 0/0 = NaN, so force one component to be
  // non-zero (same guard the normalize test applies to randomVector()).
  v[random().nextInt(v.length)] = 1;
  assertEquals(1.0f, VectorUtil.cosine(v, v), DELTA);
}
public void testOrthogonalCosine() {
  // the cosine of two perpendicular vectors is 0
  float[] v = new float[2];
  // Draw v[0] from [1, 100) instead of [0, 100): if both components were 0 the vectors would
  // have zero magnitude, cosine would be NaN, and the assertion below would fail spuriously.
  v[0] = 1 + random().nextInt(99);
  v[1] = random().nextInt(100);
  // u is v rotated by 90 degrees, (x, y) -> (y, -x), so u . v == 0
  float[] u = new float[2];
  u[0] = v[1];
  u[1] = -v[0];
  assertEquals(0, VectorUtil.cosine(u, v), DELTA);
}
public void testCosineThrowsForDimensionMismatch() {
  // comparing a 2-dimensional vector against a 3-dimensional one must be rejected
  float[] twoDims = {0, 1};
  float[] threeDims = {1, 0, 0};
  expectThrows(IllegalArgumentException.class, () -> VectorUtil.cosine(twoDims, threeDims));
}
public void testNormalize() {
float[] v = randomVector();
v[random().nextInt(v.length)] = 1; // ensure vector is not all zeroes

View File

@ -957,7 +957,8 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
// enumerators
assertEquals(0, VectorSimilarityFunction.EUCLIDEAN.ordinal());
assertEquals(1, VectorSimilarityFunction.DOT_PRODUCT.ordinal());
assertEquals(2, VectorSimilarityFunction.values().length);
assertEquals(2, VectorSimilarityFunction.COSINE.ordinal());
assertEquals(3, VectorSimilarityFunction.values().length);
}
public void testAdvance() throws Exception {