mirror of https://github.com/apache/lucene.git
Add support for similarity-based vector searches (#12679)
### Description Background in #12579 Add support for getting "all vectors within a radius" as opposed to getting the "topK closest vectors" in the current system ### Considerations I've tried to keep this change minimal and non-invasive by not modifying any APIs and re-using existing HNSW graphs -- changing the graph traversal and result collection criteria to: 1. Visit all nodes (reachable from the entry node in the last level) that are within an outer "traversal" radius 2. Collect all nodes that are within an inner "result" radius ### Advantages 1. Queries that have a high number of "relevant" results will get all of those (not limited by `topK`) 2. Conversely, arbitrary queries where many results are not "relevant" will not waste time in getting all `topK` (when some of them will be removed later) 3. Results of HNSW searches need not be sorted - and we can store them in a plain list as opposed to min-max heaps (saving on `heapify` calls). Merging results from segments is also cheaper, where we just concatenate results as opposed to calculating the index-level `topK` On a higher level, finding `topK` results needed HNSW searches to happen in `#rewrite` because of an interdependence of results between segments - where we want to find the index-level `topK` from multiple segment-level results. This is kind of against Lucene's concept of segments being independently searchable sub-indexes? Moreover, we needed explicit concurrency (#12160) to perform these in parallel, and these shortcomings would be naturally overcome with the new objective of finding "all vectors within a radius" - inherently independent of results from another segment (so we can move searches to a more fitting place?) 
### Caveats I could not find much precedent in using HNSW graphs this way (or even the radius-based search for that matter - please add links to existing work if someone is aware) and consequently marked all classes as `@lucene.experimental` For now I have re-used lots of functionality from `AbstractKnnVectorQuery` to keep this minimal, but if the use-case is accepted more widely we can look into writing more suitable queries (as mentioned above briefly)
This commit is contained in:
parent
1630ed4bd8
commit
cd195980ec
|
@ -171,7 +171,11 @@ API Changes
|
|||
|
||||
New Features
|
||||
---------------------
|
||||
(No changes)
|
||||
|
||||
* GITHUB#12679: Add support for similarity-based vector searches using [Byte|Float]VectorSimilarityQuery. Uses a new
|
||||
VectorSimilarityCollector to find all vectors scoring above a `resultSimilarity` while traversing the HNSW graph till
|
||||
better-scoring nodes are available, or the best candidate is below a score of `traversalSimilarity` in the lowest
|
||||
level. (Aditya Prakash, Kaival Parikh)
|
||||
|
||||
Improvements
|
||||
---------------------
|
||||
|
|
|
@ -0,0 +1,288 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.util.BitSet;
|
||||
import org.apache.lucene.util.BitSetIterator;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
* Search for all (approximate) vectors above a similarity threshold.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
abstract class AbstractVectorSimilarityQuery extends Query {
|
||||
protected final String field;
|
||||
protected final float traversalSimilarity, resultSimilarity;
|
||||
protected final Query filter;
|
||||
|
||||
/**
 * Builds a query that matches every (approximate) vector whose similarity to the target is at
 * least {@code resultSimilarity}, collected through a {@link VectorSimilarityCollector}. When a
 * filter is present, the graph search is limited to as many visited nodes as the filter has
 * matching docs, and an exact search over the filtered docs is used if that limit is exhausted.
 *
 * @param field a field that has been indexed as a vector field.
 * @param traversalSimilarity (lower) similarity score for graph traversal.
 * @param resultSimilarity (higher) similarity score for result collection.
 * @param filter a filter applied before the vector search, or {@code null} for no filter.
 * @throws IllegalArgumentException if {@code traversalSimilarity} is greater than {@code
 *     resultSimilarity}.
 * @throws NullPointerException if {@code field} is {@code null}.
 */
AbstractVectorSimilarityQuery(
    String field, float traversalSimilarity, float resultSimilarity, Query filter) {
  // The traversal threshold may never be stricter than the result threshold
  if (resultSimilarity < traversalSimilarity) {
    throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
  }
  Objects.requireNonNull(field, "field");

  this.field = field;
  this.filter = filter;
  this.traversalSimilarity = traversalSimilarity;
  this.resultSimilarity = resultSimilarity;
}
|
||||
|
||||
/**
 * Creates a scorer that compares the query target against the vectors of {@link #field} in the
 * given segment.
 *
 * @param context the segment to score against.
 * @return the scorer, or {@code null} when the field is not indexed as the expected kind of
 *     vector field in this segment ({@code explain} reports that case as a non-match).
 * @throws IOException if the underlying reader fails.
 */
abstract VectorScorer createVectorScorer(LeafReaderContext context) throws IOException;
|
||||
|
||||
/**
 * Runs the approximate (graph-based) search for this query on a single segment.
 *
 * @param context the segment to search.
 * @param acceptDocs the docs that may be returned; this class passes either the segment's live
 *     docs (possibly {@code null}) or the set of live docs matching the filter.
 * @param visitLimit the maximum number of nodes to visit; this class passes {@code
 *     Integer.MAX_VALUE} when there is no filter, else the number of accepted docs.
 * @return the collected hits; a {@code GREATER_THAN_OR_EQUAL_TO} relation is interpreted by the
 *     caller as "the visit limit was exhausted".
 * @throws IOException if the underlying reader fails.
 */
protected abstract TopDocs approximateSearch(
|
||||
LeafReaderContext context, Bits acceptDocs, int visitLimit) throws IOException;
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
throws IOException {
|
||||
return new Weight(this) {
|
||||
final Weight filterWeight =
|
||||
filter == null
|
||||
? null
|
||||
: searcher.createWeight(searcher.rewrite(filter), ScoreMode.COMPLETE_NO_SCORES, 1);
|
||||
|
||||
@Override
|
||||
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
|
||||
if (filterWeight != null) {
|
||||
Scorer filterScorer = filterWeight.scorer(context);
|
||||
if (filterScorer == null || filterScorer.iterator().advance(doc) > doc) {
|
||||
return Explanation.noMatch("Doc does not match the filter");
|
||||
}
|
||||
}
|
||||
|
||||
VectorScorer scorer = createVectorScorer(context);
|
||||
if (scorer == null) {
|
||||
return Explanation.noMatch("Not indexed as the correct vector field");
|
||||
} else if (scorer.advanceExact(doc)) {
|
||||
float score = scorer.score();
|
||||
if (score >= resultSimilarity) {
|
||||
return Explanation.match(boost * score, "Score above threshold");
|
||||
} else {
|
||||
return Explanation.noMatch("Score below threshold");
|
||||
}
|
||||
} else {
|
||||
return Explanation.noMatch("No vector found for doc");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Scorer scorer(LeafReaderContext context) throws IOException {
|
||||
@SuppressWarnings("resource")
|
||||
LeafReader leafReader = context.reader();
|
||||
Bits liveDocs = leafReader.getLiveDocs();
|
||||
|
||||
// If there is no filter
|
||||
if (filterWeight == null) {
|
||||
// Return exhaustive results
|
||||
TopDocs results = approximateSearch(context, liveDocs, Integer.MAX_VALUE);
|
||||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
|
||||
}
|
||||
|
||||
Scorer scorer = filterWeight.scorer(context);
|
||||
if (scorer == null) {
|
||||
// If the filter does not match any documents
|
||||
return null;
|
||||
}
|
||||
|
||||
BitSet acceptDocs;
|
||||
if (liveDocs == null && scorer.iterator() instanceof BitSetIterator bitSetIterator) {
|
||||
// If there are no deletions, and matching docs are already cached
|
||||
acceptDocs = bitSetIterator.getBitSet();
|
||||
} else {
|
||||
// Else collect all matching docs
|
||||
FilteredDocIdSetIterator filtered =
|
||||
new FilteredDocIdSetIterator(scorer.iterator()) {
|
||||
@Override
|
||||
protected boolean match(int doc) {
|
||||
return liveDocs == null || liveDocs.get(doc);
|
||||
}
|
||||
};
|
||||
acceptDocs = BitSet.of(filtered, leafReader.maxDoc());
|
||||
}
|
||||
|
||||
int cardinality = acceptDocs.cardinality();
|
||||
if (cardinality == 0) {
|
||||
// If there are no live matching docs
|
||||
return null;
|
||||
}
|
||||
|
||||
// Perform an approximate search
|
||||
TopDocs results = approximateSearch(context, acceptDocs, cardinality);
|
||||
|
||||
// If the limit was exhausted
|
||||
if (results.totalHits.relation == TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO) {
|
||||
// Return a lazy-loading iterator
|
||||
return VectorSimilarityScorer.fromAcceptDocs(
|
||||
this,
|
||||
boost,
|
||||
createVectorScorer(context),
|
||||
new BitSetIterator(acceptDocs, cardinality),
|
||||
resultSimilarity);
|
||||
} else {
|
||||
// Return an iterator over the collected results
|
||||
return VectorSimilarityScorer.fromScoreDocs(this, boost, results.scoreDocs);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isCacheable(LeafReaderContext ctx) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
if (visitor.acceptField(field)) {
|
||||
visitor.visitLeaf(this);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o)
|
||||
&& Objects.equals(field, ((AbstractVectorSimilarityQuery) o).field)
|
||||
&& Float.compare(
|
||||
((AbstractVectorSimilarityQuery) o).traversalSimilarity, traversalSimilarity)
|
||||
== 0
|
||||
&& Float.compare(((AbstractVectorSimilarityQuery) o).resultSimilarity, resultSimilarity)
|
||||
== 0
|
||||
&& Objects.equals(filter, ((AbstractVectorSimilarityQuery) o).filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(field, traversalSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
private static class VectorSimilarityScorer extends Scorer {
|
||||
final DocIdSetIterator iterator;
|
||||
final float[] cachedScore;
|
||||
|
||||
/**
 * @param weight the weight this scorer belongs to.
 * @param iterator the iterator over matching docs.
 * @param cachedScore single-element holder that the iterator fills with the score of the doc it
 *     is positioned on; {@code score()} reads from it.
 */
VectorSimilarityScorer(Weight weight, DocIdSetIterator iterator, float[] cachedScore) {
  super(weight);
  this.cachedScore = cachedScore;
  this.iterator = iterator;
}
|
||||
|
||||
static VectorSimilarityScorer fromScoreDocs(Weight weight, float boost, ScoreDoc[] scoreDocs) {
|
||||
// Sort in ascending order of docid
|
||||
Arrays.sort(scoreDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
|
||||
|
||||
float[] cachedScore = new float[1];
|
||||
DocIdSetIterator iterator =
|
||||
new DocIdSetIterator() {
|
||||
int index = -1;
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
if (index < 0) {
|
||||
return -1;
|
||||
} else if (index >= scoreDocs.length) {
|
||||
return NO_MORE_DOCS;
|
||||
} else {
|
||||
cachedScore[0] = boost * scoreDocs[index].score;
|
||||
return scoreDocs[index].doc;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public int nextDoc() {
|
||||
index++;
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public int advance(int target) {
|
||||
index =
|
||||
Arrays.binarySearch(
|
||||
scoreDocs,
|
||||
new ScoreDoc(target, 0),
|
||||
Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
|
||||
if (index < 0) {
|
||||
index = -1 - index;
|
||||
}
|
||||
return docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return scoreDocs.length;
|
||||
}
|
||||
};
|
||||
|
||||
return new VectorSimilarityScorer(weight, iterator, cachedScore);
|
||||
}
|
||||
|
||||
static VectorSimilarityScorer fromAcceptDocs(
|
||||
Weight weight,
|
||||
float boost,
|
||||
VectorScorer scorer,
|
||||
DocIdSetIterator acceptDocs,
|
||||
float threshold) {
|
||||
float[] cachedScore = new float[1];
|
||||
DocIdSetIterator iterator =
|
||||
new FilteredDocIdSetIterator(acceptDocs) {
|
||||
@Override
|
||||
protected boolean match(int doc) throws IOException {
|
||||
// Compute the dot product
|
||||
float score = scorer.score();
|
||||
cachedScore[0] = score * boost;
|
||||
return score >= threshold;
|
||||
}
|
||||
};
|
||||
|
||||
return new VectorSimilarityScorer(weight, iterator, cachedScore);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int docID() {
|
||||
return iterator.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DocIdSetIterator iterator() {
|
||||
return iterator;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float getMaxScore(int upTo) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float score() {
|
||||
return cachedScore[0];
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,145 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
* Search for all (approximate) byte vectors above a similarity threshold.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class ByteVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
||||
private final byte[] target;
|
||||
|
||||
/**
|
||||
* Search for all (approximate) byte vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
|
||||
* the filter, and then falls back to exact search if results are incomplete.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param traversalSimilarity (lower) similarity score for graph traversal.
|
||||
* @param resultSimilarity (higher) similarity score for result collection.
|
||||
* @param filter a filter applied before the vector search.
|
||||
*/
|
||||
public ByteVectorSimilarityQuery(
|
||||
String field,
|
||||
byte[] target,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
super(field, traversalSimilarity, resultSimilarity, filter);
|
||||
this.target = Objects.requireNonNull(target, "target");
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) byte vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param traversalSimilarity (lower) similarity score for graph traversal.
|
||||
* @param resultSimilarity (higher) similarity score for result collection.
|
||||
*/
|
||||
public ByteVectorSimilarityQuery(
|
||||
String field, byte[] target, float traversalSimilarity, float resultSimilarity) {
|
||||
this(field, target, traversalSimilarity, resultSimilarity, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) byte vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
|
||||
* the filter, and then falls back to exact search if results are incomplete.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param resultSimilarity similarity score for result collection.
|
||||
* @param filter a filter applied before the vector search.
|
||||
*/
|
||||
public ByteVectorSimilarityQuery(
|
||||
String field, byte[] target, float resultSimilarity, Query filter) {
|
||||
this(field, target, resultSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) byte vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnByteVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param resultSimilarity similarity score for result collection.
|
||||
*/
|
||||
public ByteVectorSimilarityQuery(String field, byte[] target, float resultSimilarity) {
|
||||
this(field, target, resultSimilarity, resultSimilarity, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
|
||||
@SuppressWarnings("resource")
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorEncoding() != VectorEncoding.BYTE) {
|
||||
return null;
|
||||
}
|
||||
return VectorScorer.create(context, fi, target);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
|
||||
throws IOException {
|
||||
KnnCollector collector =
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
|
||||
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return String.format(
|
||||
Locale.ROOT,
|
||||
"%s[field=%s target=[%d...] traversalSimilarity=%f resultSimilarity=%f filter=%s]",
|
||||
getClass().getSimpleName(),
|
||||
field,
|
||||
target[0],
|
||||
traversalSimilarity,
|
||||
resultSimilarity,
|
||||
filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o)
|
||||
&& super.equals(o)
|
||||
&& Arrays.equals(target, ((ByteVectorSimilarityQuery) o).target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = super.hashCode();
|
||||
result = 31 * result + Arrays.hashCode(target);
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,146 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Locale;
|
||||
import java.util.Objects;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.FieldInfo;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorEncoding;
|
||||
import org.apache.lucene.util.Bits;
|
||||
import org.apache.lucene.util.VectorUtil;
|
||||
|
||||
/**
|
||||
* Search for all (approximate) float vectors above a similarity threshold.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
public class FloatVectorSimilarityQuery extends AbstractVectorSimilarityQuery {
|
||||
private final float[] target;
|
||||
|
||||
/**
|
||||
* Search for all (approximate) float vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
|
||||
* the filter, and then falls back to exact search if results are incomplete.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param traversalSimilarity (lower) similarity score for graph traversal.
|
||||
* @param resultSimilarity (higher) similarity score for result collection.
|
||||
* @param filter a filter applied before the vector search.
|
||||
*/
|
||||
public FloatVectorSimilarityQuery(
|
||||
String field,
|
||||
float[] target,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
super(field, traversalSimilarity, resultSimilarity, filter);
|
||||
this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target"));
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) float vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param traversalSimilarity (lower) similarity score for graph traversal.
|
||||
* @param resultSimilarity (higher) similarity score for result collection.
|
||||
*/
|
||||
public FloatVectorSimilarityQuery(
|
||||
String field, float[] target, float traversalSimilarity, float resultSimilarity) {
|
||||
this(field, target, traversalSimilarity, resultSimilarity, null);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) float vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}. If a filter is applied, it traverses as many nodes as the cost of
|
||||
* the filter, and then falls back to exact search if results are incomplete.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param resultSimilarity similarity score for result collection.
|
||||
* @param filter a filter applied before the vector search.
|
||||
*/
|
||||
public FloatVectorSimilarityQuery(
|
||||
String field, float[] target, float resultSimilarity, Query filter) {
|
||||
this(field, target, resultSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
/**
|
||||
* Search for all (approximate) float vectors above a similarity threshold using {@link
|
||||
* VectorSimilarityCollector}.
|
||||
*
|
||||
* @param field a field that has been indexed as a {@link KnnFloatVectorField}.
|
||||
* @param target the target of the search.
|
||||
* @param resultSimilarity similarity score for result collection.
|
||||
*/
|
||||
public FloatVectorSimilarityQuery(String field, float[] target, float resultSimilarity) {
|
||||
this(field, target, resultSimilarity, resultSimilarity, null);
|
||||
}
|
||||
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) throws IOException {
|
||||
@SuppressWarnings("resource")
|
||||
FieldInfo fi = context.reader().getFieldInfos().fieldInfo(field);
|
||||
if (fi == null || fi.getVectorEncoding() != VectorEncoding.FLOAT32) {
|
||||
return null;
|
||||
}
|
||||
return VectorScorer.create(context, fi, target);
|
||||
}
|
||||
|
||||
@Override
|
||||
@SuppressWarnings("resource")
|
||||
protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, int visitLimit)
|
||||
throws IOException {
|
||||
KnnCollector collector =
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit);
|
||||
context.reader().searchNearestVectors(field, target, collector, acceptDocs);
|
||||
return collector.topDocs();
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return String.format(
|
||||
Locale.ROOT,
|
||||
"%s[field=%s target=[%f...] traversalSimilarity=%f resultSimilarity=%f filter=%s]",
|
||||
getClass().getSimpleName(),
|
||||
field,
|
||||
target[0],
|
||||
traversalSimilarity,
|
||||
resultSimilarity,
|
||||
filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object o) {
|
||||
return sameClassAs(o)
|
||||
&& super.equals(o)
|
||||
&& Arrays.equals(target, ((FloatVectorSimilarityQuery) o).target);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
int result = super.hashCode();
|
||||
result = 31 * result + Arrays.hashCode(target);
|
||||
return result;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Perform a similarity-based graph search.
|
||||
*
|
||||
* @lucene.experimental
|
||||
*/
|
||||
class VectorSimilarityCollector extends AbstractKnnCollector {
|
||||
private final float traversalSimilarity, resultSimilarity;
|
||||
private float maxSimilarity;
|
||||
private final List<ScoreDoc> scoreDocList;
|
||||
|
||||
/**
|
||||
* Perform a similarity-based graph search. The graph is traversed till better scoring nodes are
|
||||
* available, or the best candidate is below {@link #traversalSimilarity}. All traversed nodes
|
||||
* above {@link #resultSimilarity} are collected.
|
||||
*
|
||||
* @param traversalSimilarity (lower) similarity score for graph traversal.
|
||||
* @param resultSimilarity (higher) similarity score for result collection.
|
||||
* @param visitLimit limit on number of nodes to visit.
|
||||
*/
|
||||
public VectorSimilarityCollector(
|
||||
float traversalSimilarity, float resultSimilarity, long visitLimit) {
|
||||
super(1, visitLimit);
|
||||
if (traversalSimilarity > resultSimilarity) {
|
||||
throw new IllegalArgumentException("traversalSimilarity should be <= resultSimilarity");
|
||||
}
|
||||
this.traversalSimilarity = traversalSimilarity;
|
||||
this.resultSimilarity = resultSimilarity;
|
||||
this.maxSimilarity = Float.NEGATIVE_INFINITY;
|
||||
this.scoreDocList = new ArrayList<>();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean collect(int docId, float similarity) {
|
||||
maxSimilarity = Math.max(maxSimilarity, similarity);
|
||||
if (similarity >= resultSimilarity) {
|
||||
scoreDocList.add(new ScoreDoc(docId, similarity));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public float minCompetitiveSimilarity() {
|
||||
return Math.min(traversalSimilarity, maxSimilarity);
|
||||
}
|
||||
|
||||
@Override
|
||||
public TopDocs topDocs() {
|
||||
// Results are not returned in a sorted order to prevent unnecessary calculations (because we do
|
||||
// not need to maintain the topK)
|
||||
TotalHits.Relation relation =
|
||||
earlyTerminated()
|
||||
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
|
||||
: TotalHits.Relation.EQUAL_TO;
|
||||
return new TopDocs(
|
||||
new TotalHits(visitedCount(), relation), scoreDocList.toArray(ScoreDoc[]::new));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,516 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
import java.util.stream.IntStream;
|
||||
import org.apache.lucene.document.Document;
|
||||
import org.apache.lucene.document.Field;
|
||||
import org.apache.lucene.document.IntField;
|
||||
import org.apache.lucene.index.DirectoryReader;
|
||||
import org.apache.lucene.index.IndexReader;
|
||||
import org.apache.lucene.index.IndexWriter;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
abstract class BaseVectorSimilarityQueryTestCase<
|
||||
V, F extends Field, Q extends AbstractVectorSimilarityQuery>
|
||||
extends LuceneTestCase {
|
||||
String vectorField, idField;
|
||||
VectorSimilarityFunction function;
|
||||
int numDocs, dim;
|
||||
|
||||
abstract V getRandomVector(int dim);
|
||||
|
||||
abstract float compare(V vector1, V vector2);
|
||||
|
||||
abstract boolean checkEquals(V vector1, V vector2);
|
||||
|
||||
abstract F getVectorField(String name, V vector, VectorSimilarityFunction function);
|
||||
|
||||
abstract Q getVectorQuery(
|
||||
String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter);
|
||||
|
||||
abstract Q getThrowingVectorQuery(
|
||||
String field, V vector, float traversalSimilarity, float resultSimilarity, Query filter);
|
||||
|
||||
public void testEquals() {
|
||||
String field1 = "f1", field2 = "f2";
|
||||
|
||||
V vector1 = getRandomVector(dim);
|
||||
V vector2;
|
||||
do {
|
||||
vector2 = getRandomVector(dim);
|
||||
} while (checkEquals(vector1, vector2));
|
||||
|
||||
float traversalSimilarity1 = 0.3f, traversalSimilarity2 = 0.4f;
|
||||
float resultSimilarity1 = 0.4f, resultSimilarity2 = 0.5f;
|
||||
|
||||
Query filter1 = new TermQuery(new Term("t1", "v1"));
|
||||
Query filter2 = new TermQuery(new Term("t2", "v2"));
|
||||
|
||||
Query query = getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1);
|
||||
|
||||
// Everything is equal
|
||||
assertEquals(
|
||||
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter1));
|
||||
|
||||
// Null check
|
||||
assertNotEquals(query, null);
|
||||
|
||||
// Different field
|
||||
assertNotEquals(
|
||||
query, getVectorQuery(field2, vector1, traversalSimilarity1, resultSimilarity1, filter1));
|
||||
|
||||
// Different vector
|
||||
assertNotEquals(
|
||||
query, getVectorQuery(field1, vector2, traversalSimilarity1, resultSimilarity1, filter1));
|
||||
|
||||
// Different traversalSimilarity
|
||||
assertNotEquals(
|
||||
query, getVectorQuery(field1, vector1, traversalSimilarity2, resultSimilarity1, filter1));
|
||||
|
||||
// Different resultSimilarity
|
||||
assertNotEquals(
|
||||
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity2, filter1));
|
||||
|
||||
// Different filter
|
||||
assertNotEquals(
|
||||
query, getVectorQuery(field1, vector1, traversalSimilarity1, resultSimilarity1, filter2));
|
||||
}
|
||||
|
||||
public void testEmptyIndex() throws IOException {
|
||||
// Do not index any vectors
|
||||
numDocs = 0;
|
||||
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
|
||||
// Check that no vectors are found
|
||||
assertEquals(0, searcher.count(query));
|
||||
}
|
||||
}
|
||||
|
||||
public void testExtremes() throws IOException {
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
// All vectors are above -Infinity
|
||||
Query query1 =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
|
||||
// Check that all vectors are found
|
||||
assertEquals(numDocs, searcher.count(query1));
|
||||
|
||||
// No vectors are above +Infinity
|
||||
Query query2 =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.POSITIVE_INFINITY,
|
||||
Float.POSITIVE_INFINITY,
|
||||
null);
|
||||
|
||||
// Check that no vectors are found
|
||||
assertEquals(0, searcher.count(query2));
|
||||
}
|
||||
}
|
||||
|
||||
public void testRandomFilter() throws IOException {
|
||||
// Filter a sub-range from 0 to numDocs
|
||||
int startIndex = random().nextInt(numDocs);
|
||||
int endIndex = random().nextInt(startIndex, numDocs);
|
||||
Query filter = IntField.newRangeQuery(idField, startIndex, endIndex);
|
||||
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
filter);
|
||||
|
||||
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
|
||||
for (ScoreDoc scoreDoc : scoreDocs) {
|
||||
int id = getId(searcher, scoreDoc.doc);
|
||||
|
||||
// Check that returned document is in selected range
|
||||
assertTrue(id >= startIndex && id <= endIndex);
|
||||
}
|
||||
// Check that all filtered vectors are found
|
||||
assertEquals(endIndex - startIndex + 1, scoreDocs.length);
|
||||
}
|
||||
}
|
||||
|
||||
public void testFilterWithNoMatches() throws IOException {
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
// Non-existent field
|
||||
Query filter1 = new TermQuery(new Term("random_field", "random_value"));
|
||||
Query query1 =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
filter1);
|
||||
|
||||
// Check that no vectors are found
|
||||
assertEquals(0, searcher.count(query1));
|
||||
|
||||
// Field exists, but value of -1 is not indexed
|
||||
Query filter2 = IntField.newExactQuery(idField, -1);
|
||||
Query query2 =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
filter2);
|
||||
|
||||
// Check that no vectors are found
|
||||
assertEquals(0, searcher.count(query2));
|
||||
}
|
||||
}
|
||||
|
||||
public void testDimensionMismatch() throws IOException {
|
||||
// Different dimension
|
||||
int newDim = atLeast(dim + 1);
|
||||
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(newDim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
|
||||
// Check that an exception for differing dimensions is thrown
|
||||
IllegalArgumentException e =
|
||||
expectThrows(IllegalArgumentException.class, () -> searcher.count(query));
|
||||
assertEquals(
|
||||
String.format(
|
||||
Locale.ROOT,
|
||||
"vector query dimension: %d differs from field dimension: %d",
|
||||
newDim,
|
||||
dim),
|
||||
e.getMessage());
|
||||
}
|
||||
}
|
||||
|
||||
public void testNonVectorsField() throws IOException {
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
// Non-existent field
|
||||
Query query1 =
|
||||
getVectorQuery(
|
||||
"random_field",
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
assertEquals(0, searcher.count(query1));
|
||||
|
||||
// Indexed as int field
|
||||
Query query2 =
|
||||
getVectorQuery(
|
||||
idField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
assertEquals(0, searcher.count(query2));
|
||||
}
|
||||
}
|
||||
|
||||
public void testSomeDeletes() throws IOException {
|
||||
// Delete a sub-range from 0 to numDocs
|
||||
int startIndex = random().nextInt(numDocs);
|
||||
int endIndex = random().nextInt(startIndex, numDocs);
|
||||
Query delete = IntField.newRangeQuery(idField, startIndex, endIndex);
|
||||
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) {
|
||||
|
||||
w.deleteDocuments(delete);
|
||||
w.commit();
|
||||
|
||||
try (IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
|
||||
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
|
||||
for (ScoreDoc scoreDoc : scoreDocs) {
|
||||
int id = getId(searcher, scoreDoc.doc);
|
||||
|
||||
// Check that returned document is not deleted
|
||||
assertFalse(id >= startIndex && id <= endIndex);
|
||||
}
|
||||
// Check that all live docs are returned
|
||||
assertEquals(numDocs - endIndex + startIndex - 1, scoreDocs.length);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testAllDeletes() throws IOException {
|
||||
try (Directory dir = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
|
||||
// Delete all documents
|
||||
w.deleteDocuments(new MatchAllDocsQuery());
|
||||
w.commit();
|
||||
|
||||
try (IndexReader reader = DirectoryReader.open(dir)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
|
||||
// Check that no vectors are found
|
||||
assertEquals(0, searcher.count(query));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testBoostQuery() throws IOException {
|
||||
// Define the boost and allowed delta
|
||||
float boost = random().nextFloat(5, 10);
|
||||
float delta = 1e-3f;
|
||||
|
||||
try (Directory indexStore = getIndexStore(getRandomVectors(numDocs, dim));
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query1 =
|
||||
getVectorQuery(
|
||||
vectorField,
|
||||
getRandomVector(dim),
|
||||
Float.NEGATIVE_INFINITY,
|
||||
Float.NEGATIVE_INFINITY,
|
||||
null);
|
||||
ScoreDoc[] scoreDocs1 = searcher.search(query1, numDocs).scoreDocs;
|
||||
|
||||
Query query2 = new BoostQuery(query1, boost);
|
||||
ScoreDoc[] scoreDocs2 = searcher.search(query2, numDocs).scoreDocs;
|
||||
|
||||
// Check that all docs are identical, with boosted scores
|
||||
assertEquals(scoreDocs1.length, scoreDocs2.length);
|
||||
for (int i = 0; i < scoreDocs1.length; i++) {
|
||||
assertEquals(scoreDocs1[i].doc, scoreDocs2[i].doc);
|
||||
assertEquals(boost * scoreDocs1[i].score, scoreDocs2[i].score, delta);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void testVectorsAboveSimilarity() throws IOException {
|
||||
// Pick number of docs to accept
|
||||
int numAccepted = random().nextInt(numDocs / 3, numDocs / 2);
|
||||
float delta = 1e-3f;
|
||||
|
||||
V[] vectors = getRandomVectors(numDocs, dim);
|
||||
V queryVector = getRandomVector(dim);
|
||||
|
||||
// Find score above which we get (at least) numAccepted vectors
|
||||
float resultSimilarity = getSimilarity(vectors, queryVector, numAccepted);
|
||||
|
||||
// Cache scores of vectors
|
||||
Map<Integer, Float> scores = new HashMap<>();
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
float score = compare(queryVector, vectors[i]);
|
||||
if (score >= resultSimilarity) {
|
||||
scores.put(i, score);
|
||||
}
|
||||
}
|
||||
|
||||
try (Directory indexStore = getIndexStore(vectors);
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getVectorQuery(vectorField, queryVector, Float.NEGATIVE_INFINITY, resultSimilarity, null);
|
||||
|
||||
ScoreDoc[] scoreDocs = searcher.search(query, numDocs).scoreDocs;
|
||||
for (ScoreDoc scoreDoc : scoreDocs) {
|
||||
int id = getId(searcher, scoreDoc.doc);
|
||||
|
||||
// Check that the collected result is above accepted similarity
|
||||
assertTrue(scores.containsKey(id));
|
||||
|
||||
// Check that the score is correct
|
||||
assertEquals(scores.get(id), scoreDoc.score, delta);
|
||||
}
|
||||
|
||||
// Check that all results are collected
|
||||
assertEquals(scores.size(), scoreDocs.length);
|
||||
}
|
||||
}
|
||||
|
||||
public void testFallbackToExact() throws IOException {
|
||||
// Restrictive filter, along with similarity to visit a large number of nodes
|
||||
int numFiltered = random().nextInt(numDocs / 10, numDocs / 5);
|
||||
int targetVisited = random().nextInt(numFiltered * 2, numDocs);
|
||||
|
||||
V[] vectors = getRandomVectors(numDocs, dim);
|
||||
V queryVector = getRandomVector(dim);
|
||||
|
||||
float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited);
|
||||
Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
|
||||
|
||||
try (Directory indexStore = getIndexStore(vectors);
|
||||
IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getThrowingVectorQuery(
|
||||
vectorField, queryVector, resultSimilarity, resultSimilarity, filter);
|
||||
|
||||
// Falls back to exact search
|
||||
expectThrows(UnsupportedOperationException.class, () -> searcher.count(query));
|
||||
}
|
||||
}
|
||||
|
||||
public void testApproximate() throws IOException {
|
||||
// Non-restrictive filter, along with similarity to visit a small number of nodes
|
||||
int numFiltered = random().nextInt((numDocs * 4) / 5, numDocs);
|
||||
int targetVisited = random().nextInt(numFiltered / 10, numFiltered / 8);
|
||||
|
||||
V[] vectors = getRandomVectors(numDocs, dim);
|
||||
V queryVector = getRandomVector(dim);
|
||||
|
||||
float resultSimilarity = getSimilarity(vectors, queryVector, targetVisited);
|
||||
Query filter = IntField.newSetQuery(idField, getFiltered(numFiltered));
|
||||
|
||||
try (Directory indexStore = getIndexStore(vectors);
|
||||
IndexWriter w = new IndexWriter(indexStore, newIndexWriterConfig())) {
|
||||
// Force merge because smaller segments have few filtered docs and often fall back to exact
|
||||
// search, making this test flaky
|
||||
w.forceMerge(1);
|
||||
w.commit();
|
||||
|
||||
try (IndexReader reader = DirectoryReader.open(indexStore)) {
|
||||
IndexSearcher searcher = newSearcher(reader);
|
||||
|
||||
Query query =
|
||||
getThrowingVectorQuery(
|
||||
vectorField, queryVector, resultSimilarity, resultSimilarity, filter);
|
||||
|
||||
// Does not fall back to exact search
|
||||
assertTrue(searcher.count(query) <= numFiltered);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private float getSimilarity(V[] vectors, V queryVector, int targetVisited) {
|
||||
assertTrue(targetVisited >= 0 && targetVisited <= numDocs);
|
||||
if (targetVisited == 0) {
|
||||
return Float.POSITIVE_INFINITY;
|
||||
}
|
||||
|
||||
float[] scores = new float[numDocs];
|
||||
for (int i = 0; i < numDocs; i++) {
|
||||
scores[i] = compare(queryVector, vectors[i]);
|
||||
}
|
||||
Arrays.sort(scores);
|
||||
|
||||
return scores[numDocs - targetVisited];
|
||||
}
|
||||
|
||||
private int[] getFiltered(int numFiltered) {
|
||||
Set<Integer> accepted = new HashSet<>();
|
||||
for (int i = 0; i < numFiltered; ) {
|
||||
int index = random().nextInt(numDocs);
|
||||
if (!accepted.contains(index)) {
|
||||
accepted.add(index);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
return accepted.stream().mapToInt(Integer::intValue).toArray();
|
||||
}
|
||||
|
||||
private int getId(IndexSearcher searcher, int doc) throws IOException {
|
||||
return Objects.requireNonNull(searcher.storedFields().document(doc).getField(idField))
|
||||
.numericValue()
|
||||
.intValue();
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
V[] getRandomVectors(int numDocs, int dim) {
|
||||
return (V[]) IntStream.range(0, numDocs).mapToObj(i -> getRandomVector(dim)).toArray();
|
||||
}
|
||||
|
||||
@SafeVarargs
|
||||
final Directory getIndexStore(V... vectors) throws IOException {
|
||||
Directory dir = newDirectory();
|
||||
try (RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
|
||||
for (int i = 0; i < vectors.length; ++i) {
|
||||
Document doc = new Document();
|
||||
doc.add(getVectorField(vectorField, vectors[i], function));
|
||||
doc.add(new IntField(idField, i, Field.Store.YES));
|
||||
writer.addDocument(doc);
|
||||
}
|
||||
}
|
||||
return dir;
|
||||
}
|
||||
}
|
|
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.apache.lucene.document.KnnByteVectorField;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.TestVectorUtil;
|
||||
import org.junit.Before;
|
||||
|
||||
public class TestByteVectorSimilarityQuery
|
||||
extends BaseVectorSimilarityQueryTestCase<
|
||||
byte[], KnnByteVectorField, ByteVectorSimilarityQuery> {
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
vectorField = getClass().getSimpleName() + ":VectorField";
|
||||
idField = getClass().getSimpleName() + ":IdField";
|
||||
function = VectorSimilarityFunction.EUCLIDEAN;
|
||||
numDocs = atLeast(100);
|
||||
dim = atLeast(50);
|
||||
}
|
||||
|
||||
@Override
|
||||
byte[] getRandomVector(int dim) {
|
||||
return TestVectorUtil.randomVectorBytes(dim);
|
||||
}
|
||||
|
||||
@Override
|
||||
float compare(byte[] vector1, byte[] vector2) {
|
||||
return function.compare(vector1, vector2);
|
||||
}
|
||||
|
||||
@Override
|
||||
boolean checkEquals(byte[] vector1, byte[] vector2) {
|
||||
return Arrays.equals(vector1, vector2);
|
||||
}
|
||||
|
||||
@Override
|
||||
KnnByteVectorField getVectorField(String name, byte[] vector, VectorSimilarityFunction function) {
|
||||
return new KnnByteVectorField(name, vector, function);
|
||||
}
|
||||
|
||||
@Override
|
||||
ByteVectorSimilarityQuery getVectorQuery(
|
||||
String field,
|
||||
byte[] vector,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
return new ByteVectorSimilarityQuery(
|
||||
field, vector, traversalSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
ByteVectorSimilarityQuery getThrowingVectorQuery(
|
||||
String field,
|
||||
byte[] vector,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
return new ByteVectorSimilarityQuery(
|
||||
field, vector, traversalSimilarity, resultSimilarity, filter) {
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,86 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.util.Arrays;
|
||||
import org.apache.lucene.document.KnnFloatVectorField;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.util.TestVectorUtil;
|
||||
import org.junit.Before;
|
||||
|
||||
public class TestFloatVectorSimilarityQuery
|
||||
extends BaseVectorSimilarityQueryTestCase<
|
||||
float[], KnnFloatVectorField, FloatVectorSimilarityQuery> {
|
||||
|
||||
@Before
|
||||
public void setup() {
|
||||
vectorField = getClass().getSimpleName() + ":VectorField";
|
||||
idField = getClass().getSimpleName() + ":IdField";
|
||||
function = VectorSimilarityFunction.EUCLIDEAN;
|
||||
numDocs = atLeast(100);
|
||||
dim = atLeast(50);
|
||||
}
|
||||
|
||||
@Override
|
||||
float[] getRandomVector(int dim) {
|
||||
return TestVectorUtil.randomVector(dim);
|
||||
}
|
||||
|
||||
@Override
|
||||
float compare(float[] vector1, float[] vector2) {
|
||||
return function.compare(vector1, vector2);
|
||||
}
|
||||
|
||||
@Override
|
||||
boolean checkEquals(float[] vector1, float[] vector2) {
|
||||
return Arrays.equals(vector1, vector2);
|
||||
}
|
||||
|
||||
@Override
|
||||
KnnFloatVectorField getVectorField(
|
||||
String name, float[] vector, VectorSimilarityFunction function) {
|
||||
return new KnnFloatVectorField(name, vector, function);
|
||||
}
|
||||
|
||||
@Override
|
||||
FloatVectorSimilarityQuery getVectorQuery(
|
||||
String field,
|
||||
float[] vector,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
return new FloatVectorSimilarityQuery(
|
||||
field, vector, traversalSimilarity, resultSimilarity, filter);
|
||||
}
|
||||
|
||||
@Override
|
||||
FloatVectorSimilarityQuery getThrowingVectorQuery(
|
||||
String field,
|
||||
float[] vector,
|
||||
float traversalSimilarity,
|
||||
float resultSimilarity,
|
||||
Query filter) {
|
||||
return new FloatVectorSimilarityQuery(
|
||||
field, vector, traversalSimilarity, resultSimilarity, filter) {
|
||||
@Override
|
||||
VectorScorer createVectorScorer(LeafReaderContext context) {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
|
||||
public class TestVectorSimilarityCollector extends LuceneTestCase {
|
||||
public void testResultCollection() {
|
||||
float traversalSimilarity = 0.3f, resultSimilarity = 0.5f;
|
||||
|
||||
VectorSimilarityCollector collector =
|
||||
new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, Integer.MAX_VALUE);
|
||||
int[] nodes = {1, 5, 10, 4, 8, 3, 2, 6, 7, 9};
|
||||
float[] scores = {0.1f, 0.2f, 0.3f, 0.5f, 0.2f, 0.6f, 0.9f, 0.3f, 0.7f, 0.8f};
|
||||
|
||||
float[] minCompetitiveSimilarities = new float[nodes.length];
|
||||
for (int i = 0; i < nodes.length; i++) {
|
||||
collector.collect(nodes[i], scores[i]);
|
||||
minCompetitiveSimilarities[i] = collector.minCompetitiveSimilarity();
|
||||
}
|
||||
|
||||
ScoreDoc[] scoreDocs = collector.topDocs().scoreDocs;
|
||||
int[] resultNodes = new int[scoreDocs.length];
|
||||
float[] resultScores = new float[scoreDocs.length];
|
||||
for (int i = 0; i < scoreDocs.length; i++) {
|
||||
resultNodes[i] = scoreDocs[i].doc;
|
||||
resultScores[i] = scoreDocs[i].score;
|
||||
}
|
||||
|
||||
// All nodes above resultSimilarity appear in order of collection
|
||||
assertArrayEquals(new int[] {4, 3, 2, 7, 9}, resultNodes);
|
||||
assertArrayEquals(new float[] {0.5f, 0.6f, 0.9f, 0.7f, 0.8f}, resultScores, 1e-3f);
|
||||
|
||||
// Min competitive similarity is minimum of traversalSimilarity or best result encountered
|
||||
assertArrayEquals(
|
||||
new float[] {0.1f, 0.2f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f, 0.3f},
|
||||
minCompetitiveSimilarities,
|
||||
1e-3f);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue