Add support for similarity-based vector searches (#12679)

### Description

Background in #12579

Add support for getting "all vectors within a radius", as opposed to the "topK closest vectors" returned by the current system

### Considerations

I've tried to keep this change minimal and non-invasive by not modifying any existing APIs and by reusing existing HNSW graphs, changing only the graph traversal and result-collection criteria to the following (a short usage sketch is included after the list):
1. Visit all nodes (reachable from the entry node in the last level) that are within an outer "traversal" radius
2. Collect all nodes that are within an inner "result" radius
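
To make the intended usage concrete, here is a minimal sketch (the field name `"vector"`, the thresholds, and the surrounding setup are illustrative, not part of this change):

```java
import java.io.IOException;

import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.FloatVectorSimilarityQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.store.Directory;

class VectorSimilarityQueryExample {
  // Prints every doc whose "vector" field scores above the result threshold for `target`.
  // Assumes `dir` holds an index where "vector" was indexed as a KnnFloatVectorField.
  static void printSimilarDocs(Directory dir, float[] target) throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);

      // Outer "traversal" threshold (0.8f) bounds how far the HNSW graph is explored;
      // inner "result" threshold (0.9f) decides which visited nodes are collected.
      Query query = new FloatVectorSimilarityQuery("vector", target, 0.8f, 0.9f);

      // Unlike a topK query, the number of hits is not known upfront,
      // so ask for up to maxDoc() results (IndexSearcher#search requires n >= 1).
      int n = Math.max(1, reader.maxDoc());
      for (ScoreDoc scoreDoc : searcher.search(query, n).scoreDocs) {
        System.out.println(scoreDoc.doc + " -> " + scoreDoc.score);
      }
    }
  }
}
```

The single-threshold constructors simply set `traversalSimilarity` equal to `resultSimilarity`.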

### Advantages

1. Queries that have a high number of "relevant" results will get all of those (not limited by `topK`)
2. Conversely, for arbitrary queries where many results are not "relevant", we do not waste time collecting all `topK` results (only for some of them to be discarded later)
3. Results of HNSW searches need not be sorted, so we can store them in a plain list instead of a min-max heap (saving on `heapify` calls). Merging results across segments is also cheaper: we just concatenate per-segment results instead of computing the index-level `topK`

At a higher level, finding `topK` results required HNSW searches to happen in `#rewrite` because of an interdependence of results between segments: we want the index-level `topK` across multiple segment-level results. This runs somewhat against Lucene's concept of segments as independently searchable sub-indexes.

Moreover, we needed explicit concurrency (#12160) to perform these searches in parallel. Both shortcomings are naturally overcome with the new objective of finding "all vectors within a radius", which is inherently independent of results from other segments (so these searches could move to a more fitting place).

### Caveats

I could not find much precedent for using HNSW graphs this way (or even for radius-based search in general -- please add links to existing work if you are aware of any), and have consequently marked all classes as `@lucene.experimental`.

For now I have reused a lot of functionality from `AbstractKnnVectorQuery` to keep this minimal, but if the use case gains wider acceptance we can look into writing more suitable queries (as briefly mentioned above).
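
As a rough sketch of the filtered case (the `"year"` filter, field names, and threshold below are hypothetical), a pre-filter can be passed to the query; as described in the query javadocs below, the graph traversal then visits at most as many nodes as the filter's cost before falling back to an exact search over the filtered documents:

```java
import java.io.IOException;

import org.apache.lucene.document.IntPoint;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.search.ByteVectorSimilarityQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.store.Directory;

class FilteredVectorSimilarityExample {
  // Counts docs whose "byte_vector" field scores at least 0.75f against `target`,
  // restricted to docs matching a (hypothetical) "year" range filter.
  static int countFilteredSimilar(Directory dir, byte[] target) throws IOException {
    try (DirectoryReader reader = DirectoryReader.open(dir)) {
      IndexSearcher searcher = new IndexSearcher(reader);

      // Pre-filter applied before the vector search
      Query filter = IntPoint.newRangeQuery("year", 2020, 2023);

      // Single-threshold constructor: traversalSimilarity == resultSimilarity == 0.75f
      Query query = new ByteVectorSimilarityQuery("byte_vector", target, 0.75f, filter);

      return searcher.count(query);
    }
  }
}
```
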
Commit cd195980ec (parent 1630ed4bd8)
Kaival Parikh, 2023-12-12 00:48:36 +05:30, committed by GitHub
9 changed files with 1403 additions and 1 deletion

lucene/CHANGES.txt
@@ -171,7 +171,11 @@ API Changes
New Features
---------------------
(No changes)
* GITHUB#12679: Add support for similarity-based vector searches using [Byte|Float]VectorSimilarityQuery. Uses a new
VectorSimilarityCollector to find all vectors scoring above a `resultSimilarity` while traversing the HNSW graph till
better-scoring nodes are available, or the best candidate is below a score of `traversalSimilarity` in the lowest
level. (Aditya Prakash, Kaival Parikh)
Improvements
---------------------

org/apache/lucene/search/AbstractVectorSimilarityQuery.java (new file)
@@ -0,0 +1,288 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
/**
* Search for all (approximate) vectors above a similarity threshold.
*
* @lucene.experimental
*/
abstract class AbstractVectorSimilarityQuery extends Query {
protected final String field;
protected final float traversalSimilarity, resultSimilarity;
protected final Query filter;
/**
* Search for all (approximate) vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a vector field.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
AbstractVectorSimilarityQuery(
String field, float traversalSimilarity, float resultSimilarity, Query filter) {
if (traversalSimilarity > resultSimilarity) {
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
}
this.field = Objects.requireNonNull(field, "field");
this.traversalSimilarity = traversalSimilarity;
this.resultSimilarity = resultSimilarity;
this.filter = filter;
}
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
protected abstract TopDocs approximateSearch(
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
return new Weight(this) {
final Weight filterWeight =
filter == null
? null
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
if (filterWeight != null) {
Scorer filterScorer = filterWeight.scorer(context);
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) {
return Explanation.noMatch("Doc does not match the filter");
}
}
VectorScorer scorer = createVectorScorer(context);
if (scorer == null) {
return Explanation.noMatch("Not indexed as the correct vector field");
} else if (scorer.advanceExact(doc)) {
float score = scorer.score();
if (score >= resultSimilarity) {
return Explanation.match(boost * score, "Score above threshold");
} else {
return Explanation.noMatch("Score below threshold");
}
} else {
return Explanation.noMatch("No vector found for doc");
}
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
LeafReader leafReader = context.reader();
Bits liveDocs = leafReader.getLiveDocs();
// If there is no filter
if (filterWeight == null) {
// Return exhaustive results
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}
Scorer scorer = filterWeight.scorer(context);
if (scorer == null) {
// If the filter does not match any documents
return null;
}
BitSet acceptDocs;
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
// If there are no deletions, and matching docs are already cached
acceptDocs = bitSetIterator.getBitSet();
} else {
// Else collect all matching docs
FilteredDocIdSetIterator filtered =
new FilteredDocIdSetIterator(scorer.iterator()) {
@Override
protected boolean match(int doc) {
return liveDocs == null || liveDocs.get(doc);
}
};
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
}
int cardinality = acceptDocs.cardinality();
if (cardinality == 0) {
// If there are no live matching docs
return null;
}
// Perform an approximate search
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
// If the limit was exhausted
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
// Return a lazy-loading iterator
return VectorSimilarityScorer.fromAcceptDocs(
this,
boost,
createVectorScorer(context),
new BitSetIterator(acceptDocs, cardinality),
resultSimilarity);
} else {
// Return an iterator over the collected results
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
}
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
@Override
public void visit(QueryVisitor visitor) {
if (visitor.acceptField(field)) {
visitor.visitLeaf(this);
}
}
@Override
public boolean equals(Object o) {
return sameClassAs(o)
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field)
&& Float.compare(
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity)
== 0
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity)
== 0
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter);
}
@Override
public int hashCode() {
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
}
private static class VectorSimilarityScorer extends Scorer {
final DocIdSetIterator iterator;
final float[] cachedScore;
VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) {
super(weight);
this.iterator = iterator;
this.cachedScore = cachedScore;
}
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
// Sort in ascending order of docid
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
float[] cachedScore = new float[1];
DocIdSetIterator iterator =
new DocIdSetIterator() {
int index = -1;
@Override
public int docID() {
if (index < 0) {
return -1;
} else if (index >= scoreDocs.length) {
return NO_MORE_DOCS;
} else {
cachedScore[0] = boost * scoreDocs[index].score;
return scoreDocs[index].doc;
}
}
@Override
public int nextDoc() {
index++;
return docID();
}
@Override
public int advance(int target) {
index =
Arrays.binarySearch(
scoreDocs,
new ScoreDoc(target, 0),
Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
if (index < 0) {
index = -1 - index;
}
return docID();
}
@Override
public long cost() {
return scoreDocs.length;
}
};
return new VectorSimilarityScorer(weight, iterator, cachedScore);
}
static VectorSimilarityScorer fromAcceptDocs(
Weight weight,
float boost,
VectorScorer scorer,
DocIdSetIterator acceptDocs,
float threshold) {
float[] cachedScore = new float[1];
DocIdSetIterator iterator =
new FilteredDocIdSetIterator(acceptDocs) {
@Override
protected boolean match(int doc) throws IOException {
// Advance the scorer to this doc before computing its similarity score
if (!scorer.advanceExact(doc)) {
return false;
}
float score = scorer.score();
cachedScore[0] = score * boost;
return score >= threshold;
}
};
return new VectorSimilarityScorer(weight, iterator, cachedScore);
}
@Override
public int docID() {
return iterator.docID();
}
@Override
public DocIdSetIterator iterator() {
return iterator;
}
@Override
public float getMaxScore(int upTo) {
return Float.POSITIVE_INFINITY;
}
@Override
public float score() {
return cachedScore[0];
}
}
}

org/apache/lucene/search/ByteVectorSimilarityQuery.java (new file)
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
/**
* Search for all (approximate) byte vectors above a similarity threshold.
*
* @lucene.experimental
*/
public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
private final byte[] target;
/**
* Search for all (approximate) byte vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
* @param target the target of the search.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
public ByteVectorSimilarityQuery(
String field,
byte[] target,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
super(field, traversalSimilarity, resultSimilarity, filter);
this.target = Objects.requireNonNull(target, "target");
}
/**
* Search for all (approximate) byte vectors above a similarity threshold using {@link
* VectorSimilarityCollector}.
*
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
* @param target the target of the search.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
*/
public ByteVectorSimilarityQuery(
String field, byte[] target, float traversalSimilarity, float resultSimilarity) {
this(field, target, traversalSimilarity, resultSimilarity, null);
}
/**
* Search for all (approximate) byte vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
* @param target the target of the search.
* @param resultSimilarity similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
public ByteVectorSimilarityQuery(
String field, byte[] target, float resultSimilarity, Query filter) {
this(field, target, resultSimilarity, resultSimilarity, filter);
}
/**
* Search for all (approximate) byte vectors above a similarity threshold using {@link
* VectorSimilarityCollector}.
*
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
* @param target the target of the search.
* @param resultSimilarity similarity score for result collection.
*/
public ByteVectorSimilarityQuery(String field, byte[] target, float resultSimilarity) {
this(field, target, resultSimilarity, resultSimilarity, null);
}
@Override
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) {
return null;
}
return VectorScorer.create(context, fi, target);
}
@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}
@Override
public String toString(String field) {
return String.format(
Locale.ROOT,
"%s[field=%s target=[%d...] traversalSimilarity=%f resultSimilarity=%f filter=%s]",
getClass().getSimpleName(),
field,
target[0],
traversalSimilarity,
resultSimilarity,
filter);
}
@Override
public boolean equals(Object o) {
return sameClassAs(o)
&& super.equals(o)
&& Arrays.equals(target, ((ByteVectorSimilarityQuery) o).target);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(target);
return result;
}
}

org/apache/lucene/search/FloatVectorSimilarityQuery.java (new file)
@@ -0,0 +1,146 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.VectorUtil;
/**
* Search for all (approximate) float vectors above a similarity threshold.
*
* @lucene.experimental
*/
public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
private final float[] target;
/**
* Search for all (approximate) float vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
* @param target the target of the search.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
public FloatVectorSimilarityQuery(
String field,
float[] target,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
super(field, traversalSimilarity, resultSimilarity, filter);
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
}
/**
* Search for all (approximate) float vectors above a similarity threshold using {@link
* VectorSimilarityCollector}.
*
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
* @param target the target of the search.
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
*/
public FloatVectorSimilarityQuery(
String field, float[] target, float traversalSimilarity, float resultSimilarity) {
this(field, target, traversalSimilarity, resultSimilarity, null);
}
/**
* Search for all (approximate) float vectors above a similarity threshold using {@link
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
* the filter, and then falls back to exact search if results are incomplete.
*
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
* @param target the target of the search.
* @param resultSimilarity similarity score for result collection.
* @param filter a filter applied before the vector search.
*/
public FloatVectorSimilarityQuery(
String field, float[] target, float resultSimilarity, Query filter) {
this(field, target, resultSimilarity, resultSimilarity, filter);
}
/**
* Search for all (approximate) float vectors above a similarity threshold using {@link
* VectorSimilarityCollector}.
*
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
* @param target the target of the search.
* @param resultSimilarity similarity score for result collection.
*/
public FloatVectorSimilarityQuery(String field, float[] target, float resultSimilarity) {
this(field, target, resultSimilarity, resultSimilarity, null);
}
@Override
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
@SuppressWarnings("resource")
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
return null;
}
return VectorScorer.create(context, fi, target);
}
@Override
@SuppressWarnings("resource")
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
throws IOException {
KnnCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
return collector.topDocs();
}
@Override
public String toString(String field) {
return String.format(
Locale.ROOT,
"%s[field=%s target=[%f...] traversalSimilarity=%f resultSimilarity=%f filter=%s]",
getClass().getSimpleName(),
field,
target[0],
traversalSimilarity,
resultSimilarity,
filter);
}
@Override
public boolean equals(Object o) {
return sameClassAs(o)
&& super.equals(o)
&& Arrays.equals(target, ((FloatVectorSimilarityQuery) o).target);
}
@Override
public int hashCode() {
int result = super.hashCode();
result = 31 * result + Arrays.hashCode(target);
return result;
}
}

org/apache/lucene/search/VectorSimilarityCollector.java (new file)
@@ -0,0 +1,78 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.ArrayList;
import java.util.List;
/**
* Perform a similarity-based graph search.
*
* @lucene.experimental
*/
class VectorSimilarityCollector extends AbstractKnnCollector {
private final float traversalSimilarity, resultSimilarity;
private float maxSimilarity;
private final List<ScoreDoc> scoreDocList;
/**
* Perform a similarity-based graph search. The graph is traversed till better scoring nodes are
* available, or the best candidate is below {@link #traversalSimilarity}. All traversed nodes
* above {@link #resultSimilarity} are collected.
*
* @param traversalSimilarity (lower) similarity score for graph traversal.
* @param resultSimilarity (higher) similarity score for result collection.
* @param visitLimit limit on number of nodes to visit.
*/
public VectorSimilarityCollector(
float traversalSimilarity, float resultSimilarity, long visitLimit) {
super(1, visitLimit);
if (traversalSimilarity > resultSimilarity) {
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
}
this.traversalSimilarity = traversalSimilarity;
this.resultSimilarity = resultSimilarity;
this.maxSimilarity = Float.NEGATIVE_INFINITY;
this.scoreDocList = new ArrayList<>();
}
@Override
public boolean collect(int docId, float similarity) {
maxSimilarity = Math.max(maxSimilarity, similarity);
if (similarity >= resultSimilarity) {
scoreDocList.add(new ScoreDoc(docId, similarity));
}
return true;
}
@Override
public float minCompetitiveSimilarity() {
return Math.min(traversalSimilarity, maxSimilarity);
}
@Override
public TopDocs topDocs() {
// Results are not returned in a sorted order to prevent unnecessary calculations (because we do
// not need to maintain the topK)
TotalHits.Relation relation =
earlyTerminated()
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
return new TopDocs(
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
}
}

org/apache/lucene/search/BaseVectorSimilarityQueryTestCase.java (new test file)
@@ -0,0 +1,516 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.IntStream;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.IntField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
abstract class BaseVectorSimilarityQueryTestCase<
V, F extends Field, Q extends AbstractVectorSimilarityQuery>
extends LuceneTestCase {
String vectorField, idField;
VectorSimilarityFunction function;
int numDocs, dim;
abstract V getRandomVector(int dim);
abstract float compare(V vector1, V vector2);
abstract boolean checkEquals(V vector1, V vector2);
abstract F getVectorField(String name, V vector, VectorSimilarityFunction function);
abstract Q getVectorQuery(
String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter);
abstract Q getThrowingVectorQuery(
String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter);
public void testEquals() {
String field1 = "f1", field2 = "f2";
V vector1 = getRandomVector(dim);
V vector2;
do {
vector2 = getRandomVector(dim);
} while (checkEquals(vector1, vector2));
float traversalSimilarity1 = 0.3f, traversalSimilarity2 = 0.4f;
float resultSimilarity1 = 0.4f, resultSimilarity2 = 0.5f;
Query filter1 = new TermQuery(new Term("t1", "v1"));
Query filter2 = new TermQuery(new Term("t2", "v2"));
Query query = getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1);
// Everything is equal
assertEquals(
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1));
// Null check
assertNotEquals(query, null);
// Different field
assertNotEquals(
query, getVectorQuery(field2, vector1, traversalSimilarity1, resultSimilarity1, filter1));
// Different vector
assertNotEquals(
query, getVectorQuery(field1, vector2, traversalSimilarity1, resultSimilarity1, filter1));
// Different traversalSimilarity
assertNotEquals(
query, getVectorQuery(field1, vector1, traversalSimilarity2, resultSimilarity1, filter1));
// Different resultSimilarity
assertNotEquals(
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity2, filter1));
// Different filter
assertNotEquals(
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter2));
}
public void testEmptyIndex() throws IOException {
// Do not index any vectors
numDocs = 0;
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
// Check that no vectors are found
assertEquals(0, searcher.count(query));
}
}
public void testExtremes() throws IOException {
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
// All vectors are above -Infinity
Query query1 =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
// Check that all vectors are found
assertEquals(numDocs, searcher.count(query1));
// No vectors are above +Infinity
Query query2 =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.POSITIVE_INFINITY,
Float.POSITIVE_INFINITY,
null);
// Check that no vectors are found
assertEquals(0, searcher.count(query2));
}
}
public void testRandomFilter() throws IOException {
// Filter a sub-range from 0 to numDocs
int startIndex = random().nextInt(numDocs);
int endIndex = random().nextInt(startIndex, numDocs);
Query filter = IntField.newRangeQuery(idField, startIndex, endIndex);
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
filter);
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
for (ScoreDoc scoreDoc : scoreDocs) {
int id = getId(searcher, scoreDoc.doc);
// Check that returned document is in selected range
assertTrue(id >= startIndex && id <= endIndex);
}
// Check that all filtered vectors are found
assertEquals(endIndex - startIndex + 1, scoreDocs.length);
}
}
public void testFilterWithNoMatches() throws IOException {
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
// Non-existent field
Query filter1 = new TermQuery(new Term("random_field", "random_value"));
Query query1 =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
filter1);
// Check that no vectors are found
assertEquals(0, searcher.count(query1));
// Field exists, but value of -1 is not indexed
Query filter2 = IntField.newExactQuery(idField, -1);
Query query2 =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
filter2);
// Check that no vectors are found
assertEquals(0, searcher.count(query2));
}
}
public void testDimensionMismatch() throws IOException {
// Different dimension
int newDim = atLeast(dim + 1);
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(
vectorField,
getRandomVector(newDim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
// Check that an exception for differing dimensions is thrown
IllegalArgumentException e =
expectThrows(IllegalArgumentException.class, () -> searcher.count(query));
assertEquals(
String.format(
Locale.ROOT,
"vector query dimension: %d differs from field dimension: %d",
newDim,
dim),
e.getMessage());
}
}
public void testNonVectorsField() throws IOException {
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
// Non-existent field
Query query1 =
getVectorQuery(
"random_field",
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
assertEquals(0, searcher.count(query1));
// Indexed as int field
Query query2 =
getVectorQuery(
idField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
assertEquals(0, searcher.count(query2));
}
}
public void testSomeDeletes() throws IOException {
// Delete a sub-range from 0 to numDocs
int startIndex = random().nextInt(numDocs);
int endIndex = random().nextInt(startIndex, numDocs);
Query delete = IntField.newRangeQuery(idField, startIndex, endIndex);
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) {
w.deleteDocuments(delete);
w.commit();
try (IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
for (ScoreDoc scoreDoc : scoreDocs) {
int id = getId(searcher, scoreDoc.doc);
// Check that returned document is not deleted
assertFalse(id >= startIndex && id <= endIndex);
}
// Check that all live docs are returned
assertEquals(numDocs - endIndex + startIndex - 1, scoreDocs.length);
}
}
}
public void testAllDeletes() throws IOException {
try (Directory dir = getIndexStore(getRandomVectors(numDocs, dim));
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
// Delete all documents
w.deleteDocuments(new MatchAllDocsQuery());
w.commit();
try (IndexReader reader = DirectoryReader.open(dir)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
// Check that no vectors are found
assertEquals(0, searcher.count(query));
}
}
}
public void testBoostQuery() throws IOException {
// Define the boost and allowed delta
float boost = random().nextFloat(5, 10);
float delta = 1e-3f;
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query1 =
getVectorQuery(
vectorField,
getRandomVector(dim),
Float.NEGATIVE_INFINITY,
Float.NEGATIVE_INFINITY,
null);
ScoreDoc[] scoreDocs1 = searcher.search(query1, numDocs).scoreDocs;
Query query2 = new BoostQuery(query1, boost);
ScoreDoc[] scoreDocs2 = searcher.search(query2, numDocs).scoreDocs;
// Check that all docs are identical, with boosted scores
assertEquals(scoreDocs1.length, scoreDocs2.length);
for (int i = 0; i < scoreDocs1.length; i++) {
assertEquals(scoreDocs1[i].doc, scoreDocs2[i].doc);
assertEquals(boost * scoreDocs1[i].score, scoreDocs2[i].score, delta);
}
}
}
public void testVectorsAboveSimilarity() throws IOException {
// Pick number of docs to accept
int numAccepted = random().nextInt(numDocs / 3, numDocs / 2);
float delta = 1e-3f;
V[] vectors = getRandomVectors(numDocs, dim);
V queryVector = getRandomVector(dim);
// Find score above which we get (at least) numAccepted vectors
float resultSimilarity = getSimilarity(vectors, queryVector, numAccepted);
// Cache scores of vectors
Map<Integer, Float> scores = new HashMap<>();
for (int i = 0; i < numDocs; i++) {
float score = compare(queryVector, vectors[i]);
if (score >= resultSimilarity) {
scores.put(i, score);
}
}
try (Directory indexStore = getIndexStore(vectors);
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getVectorQuery(vectorField, queryVector, Float.NEGATIVE_INFINITY, resultSimilarity, null);
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
for (ScoreDoc scoreDoc : scoreDocs) {
int id = getId(searcher, scoreDoc.doc);
// Check that the collected result is above accepted similarity
assertTrue(scores.containsKey(id));
// Check that the score is correct
assertEquals(scores.get(id), scoreDoc.score, delta);
}
// Check that all results are collected
assertEquals(scores.size(), scoreDocs.length);
}
}
public void testFallbackToExact() throws IOException {
// Restrictive filter, along with similarity to visit a large number of nodes
int numFiltered = random().nextInt(numDocs / 10, numDocs / 5);
int targetVisited = random().nextInt(numFiltered * 2, numDocs);
V[] vectors = getRandomVectors(numDocs, dim);
V queryVector = getRandomVector(dim);
float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited);
Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
try (Directory indexStore = getIndexStore(vectors);
IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getThrowingVectorQuery(
vectorField, queryVector, resultSimilarity, resultSimilarity, filter);
// Falls back to exact search
expectThrows(UnsupportedOperationException.class, () -> searcher.count(query));
}
}
public void testApproximate() throws IOException {
// Non-restrictive filter, along with similarity to visit a small number of nodes
int numFiltered = random().nextInt((numDocs * 4) / 5, numDocs);
int targetVisited = random().nextInt(numFiltered / 10, numFiltered / 8);
V[] vectors = getRandomVectors(numDocs, dim);
V queryVector = getRandomVector(dim);
float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited);
Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
try (Directory indexStore = getIndexStore(vectors);
IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) {
// Force merge because smaller segments have few filtered docs and often fall back to exact
// search, making this test flaky
w.forceMerge(1);
w.commit();
try (IndexReader reader = DirectoryReader.open(indexStore)) {
IndexSearcher searcher = newSearcher(reader);
Query query =
getThrowingVectorQuery(
vectorField, queryVector, resultSimilarity, resultSimilarity, filter);
// Does not fall back to exact search
assertTrue(searcher.count(query) <= numFiltered);
}
}
}
private float getSimilarity(V[] vectors, V queryVector, int targetVisited) {
assertTrue(targetVisited >= 0 && targetVisited <= numDocs);
if (targetVisited == 0) {
return Float.POSITIVE_INFINITY;
}
float[] scores = new float[numDocs];
for (int i = 0; i < numDocs; i++) {
scores[i] = compare(queryVector, vectors[i]);
}
Arrays.sort(scores);
return scores[numDocs - targetVisited];
}
private int[] getFiltered(int numFiltered) {
Set<Integer> accepted = new HashSet<>();
for (int i = 0; i < numFiltered; ) {
int index = random().nextInt(numDocs);
if (!accepted.contains(index)) {
accepted.add(index);
i++;
}
}
return accepted.stream().mapToInt(Integer::intValue).toArray();
}
private int getId(IndexSearcher searcher, int doc) throws IOException {
return Objects.requireNonNull(searcher.storedFields().document(doc).getField(idField))
.numericValue()
.intValue();
}
@SuppressWarnings("unchecked")
V[] getRandomVectors(int numDocs, int dim) {
return (V[]) IntStream.range(0, numDocs).mapToObj(i -> getRandomVector(dim)).toArray();
}
@SafeVarargs
final Directory getIndexStore(V... vectors) throws IOException {
Directory dir = newDirectory();
try (RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
for (int i = 0; i < vectors.length; ++i) {
Document doc = new Document();
doc.add(getVectorField(vectorField, vectors[i], function));
doc.add(new IntField(idField, i, Field.Store.YES));
writer.addDocument(doc);
}
}
return dir;
}
}

org/apache/lucene/search/TestByteVectorSimilarityQuery.java (new test file)
@@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.Arrays;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.TestVectorUtil;
import org.junit.Before;
public class TestByteVectorSimilarityQuery
extends BaseVectorSimilarityQueryTestCase<
byte[], KnnByteVectorField, ByteVectorSimilarityQuery> {
@Before
public void setup() {
vectorField = getClass().getSimpleName() + ":VectorField";
idField = getClass().getSimpleName() + ":IdField";
function = VectorSimilarityFunction.EUCLIDEAN;
numDocs = atLeast(100);
dim = atLeast(50);
}
@Override
byte[] getRandomVector(int dim) {
return TestVectorUtil.randomVectorBytes(dim);
}
@Override
float compare(byte[] vector1, byte[] vector2) {
return function.compare(vector1, vector2);
}
@Override
boolean checkEquals(byte[] vector1, byte[] vector2) {
return Arrays.equals(vector1, vector2);
}
@Override
KnnByteVectorField getVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
return new KnnByteVectorField(name, vector, function);
}
@Override
ByteVectorSimilarityQuery getVectorQuery(
String field,
byte[] vector,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
return new ByteVectorSimilarityQuery(
field, vector, traversalSimilarity, resultSimilarity, filter);
}
@Override
ByteVectorSimilarityQuery getThrowingVectorQuery(
String field,
byte[] vector,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
return new ByteVectorSimilarityQuery(
field, vector, traversalSimilarity, resultSimilarity, filter) {
@Override
VectorScorer createVectorScorer(LeafReaderContext context) {
throw new UnsupportedOperationException();
}
};
}
}

org/apache/lucene/search/TestFloatVectorSimilarityQuery.java (new test file)
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.util.Arrays;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.TestVectorUtil;
import org.junit.Before;
public class TestFloatVectorSimilarityQuery
extends BaseVectorSimilarityQueryTestCase<
float[], KnnFloatVectorField, FloatVectorSimilarityQuery> {
@Before
public void setup() {
vectorField = getClass().getSimpleName() + ":VectorField";
idField = getClass().getSimpleName() + ":IdField";
function = VectorSimilarityFunction.EUCLIDEAN;
numDocs = atLeast(100);
dim = atLeast(50);
}
@Override
float[] getRandomVector(int dim) {
return TestVectorUtil.randomVector(dim);
}
@Override
float compare(float[] vector1, float[] vector2) {
return function.compare(vector1, vector2);
}
@Override
boolean checkEquals(float[] vector1, float[] vector2) {
return Arrays.equals(vector1, vector2);
}
@Override
KnnFloatVectorField getVectorField(
String name, float[] vector, VectorSimilarityFunction function) {
return new KnnFloatVectorField(name, vector, function);
}
@Override
FloatVectorSimilarityQuery getVectorQuery(
String field,
float[] vector,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
return new FloatVectorSimilarityQuery(
field, vector, traversalSimilarity, resultSimilarity, filter);
}
@Override
FloatVectorSimilarityQuery getThrowingVectorQuery(
String field,
float[] vector,
float traversalSimilarity,
float resultSimilarity,
Query filter) {
return new FloatVectorSimilarityQuery(
field, vector, traversalSimilarity, resultSimilarity, filter) {
@Override
VectorScorer createVectorScorer(LeafReaderContext context) {
throw new UnsupportedOperationException();
}
};
}
}

org/apache/lucene/search/TestVectorSimilarityCollector.java (new test file)
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import org.apache.lucene.tests.util.LuceneTestCase;
public class TestVectorSimilarityCollector extends LuceneTestCase {
public void testResultCollection() {
float traversalSimilarity = 0.3f, resultSimilarity = 0.5f;
VectorSimilarityCollector collector =
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, Integer.MAX_VALUE);
int[] nodes = {1, 5, 10, 4, 8, 3, 2, 6, 7, 9};
float[] scores = {0.1f, 0.2f, 0.3f, 0.5f, 0.2f, 0.6f, 0.9f, 0.3f, 0.7f, 0.8f};
float[] minCompetitiveSimilarities = new float[nodes.length];
for (int i = 0; i < nodes.length; i++) {
collector.collect(nodes[i], scores[i]);
minCompetitiveSimilarities[i] = collector.minCompetitiveSimilarity();
}
ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
int[] resultNodes = new int[scoreDocs.length];
float[] resultScores = new float[scoreDocs.length];
for (int i = 0; i < scoreDocs.length; i++) {
resultNodes[i] = scoreDocs[i].doc;
resultScores[i] = scoreDocs[i].score;
}
// All nodes above resultSimilarity appear in order of collection
assertArrayEquals(new int[] {4, 3, 2, 7, 9}, resultNodes);
assertArrayEquals(new float[] {0.5f, 0.6f, 0.9f, 0.7f, 0.8f}, resultScores, 1e-3f);
// Min competitive similarity is minimum of traversalSimilarity or best result encountered
assertArrayEquals(
new float[] {0.1f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f},
minCompetitiveSimilarities,
1e-3f);
}
}