mirror of https://github.com/apache/lucene.git
LUCENE-9614: add KnnVectorQuery implementation
parent a9fb5a965d
commit 624560a3d7
Lucene90HnswVectorsReader.java
@@ -259,7 +259,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
      float score = results.topScore();
      results.pop();
      if (reversed) {
-        score = (float) Math.exp(-score / target.length);
+        score = 1 / (1 + score);
      }
      scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
    }
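Note on the change above: the removed line converted the raw (squared) Euclidean distance into a score via an exponential decay normalized by the vector dimension, while the added line uses a simple reciprocal, 1 / (1 + distance), so an exact match scores 1.0 and scores fall toward 0 as distance grows. Below is a minimal illustrative sketch of that mapping; the class and method names are hypothetical and not part of the patch, and the example values mirror TestKnnVectorQuery#testScore further down (the 1/6 score asserted there implies the raw value is the squared L2 distance).

// Sketch only: the real conversion is the single inline line added above.
public class EuclideanScoreSketch {
  // Convert a raw (squared) L2 distance into a Lucene score in (0, 1].
  static float toScore(float squaredDistance) {
    return 1 / (1 + squaredDistance);
  }

  public static void main(String[] args) {
    System.out.println(toScore(0f)); // 1.0   (identical vectors)
    System.out.println(toScore(1f)); // 0.5   (e.g. target (2,3) vs. vector (2,2), as in testScore)
    System.out.println(toScore(5f)); // ~0.167 (e.g. target (2,3) vs. vector (1,1))
  }
}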
KnnVectorQuery.java (new file)
@@ -0,0 +1,307 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;

/** Uses {@link KnnVectorsReader#search} to perform nearest neighbor search. */
public class KnnVectorQuery extends Query {

  private static final TopDocs NO_RESULTS =
      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);

  private final String field;
  private final float[] target;
  private final int k;

  /**
   * Find the <code>k</code> nearest documents to the target vector according to the vectors in the
   * given field.
   *
   * @param field a field that has been indexed as a {@link KnnVectorField}.
   * @param target the target of the search
   * @param k the number of documents to find
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   */
  public KnnVectorQuery(String field, float[] target, int k) {
    this.field = field;
    this.target = target;
    this.k = k;
    if (k < 1) {
      throw new IllegalArgumentException("k must be at least 1, got: " + k);
    }
  }

  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      perLeafResults[ctx.ord] = searchLeaf(ctx, Math.min(k, reader.numDocs()));
    }
    // Merge sort the results
    TopDocs topK = TopDocs.merge(k, perLeafResults);
    if (topK.scoreDocs.length == 0) {
      return new MatchNoDocsQuery();
    }
    return createRewrittenQuery(reader, topK);
  }

  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
    if (results == null) {
      return NO_RESULTS;
    }
    if (ctx.docBase > 0) {
      for (ScoreDoc scoreDoc : results.scoreDocs) {
        scoreDoc.doc += ctx.docBase;
      }
    }
    return results;
  }

  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
    int len = topK.scoreDocs.length;
    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
    int[] docs = new int[len];
    float[] scores = new float[len];
    for (int i = 0; i < len; i++) {
      docs[i] = topK.scoreDocs[i].doc;
      scores[i] = topK.scoreDocs[i].score;
    }
    int[] segmentStarts = findSegmentStarts(reader, docs);
    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.hashCode());
  }

  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
    int[] starts = new int[reader.leaves().size() + 1];
    starts[starts.length - 1] = docs.length;
    if (starts.length == 2) {
      return starts;
    }
    int resultIndex = 0;
    for (int i = 1; i < starts.length - 1; i++) {
      int upper = reader.leaves().get(i).docBase;
      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
      if (resultIndex < 0) {
        resultIndex = -1 - resultIndex;
      }
      starts[i] = resultIndex;
    }
    return starts;
  }

  @Override
  public String toString(String field) {
    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
  }

  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(field)) {
      visitor.visitLeaf(this);
    }
  }

  @Override
  public boolean equals(Object obj) {
    return obj instanceof KnnVectorQuery
        && ((KnnVectorQuery) obj).k == k
        && ((KnnVectorQuery) obj).field.equals(field)
        && Arrays.equals(((KnnVectorQuery) obj).target, target);
  }

  @Override
  public int hashCode() {
    return Objects.hash(field, k, Arrays.hashCode(target));
  }

  /** Caches the results of a KnnVector search: a list of docs and their scores */
  static class DocAndScoreQuery extends Query {

    private final int k;
    private final int[] docs;
    private final float[] scores;
    private final int[] segmentStarts;
    private final int readerHash;

    /**
     * Constructor
     *
     * @param k the number of documents requested
     * @param docs the global docids of documents that match, in ascending order
     * @param scores the scores of the matching documents
     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
     *     document in each segment. If a segment has no matching documents, it should be assigned
     *     the index of the next segment that does. There should be a final entry that is always
     *     docs.length.
     * @param readerHash a hash code identifying the IndexReader used to create this query
     */
    DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, int readerHash) {
      this.k = k;
      this.docs = docs;
      this.scores = scores;
      this.segmentStarts = segmentStarts;
      this.readerHash = readerHash;
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
        throws IOException {
      if (searcher.getIndexReader().hashCode() != readerHash) {
        throw new IllegalStateException("This DocAndScore query was created by a different reader");
      }
      return new Weight(this) {
        @Override
        public Explanation explain(LeafReaderContext context, int doc) {
          int found = Arrays.binarySearch(docs, doc);
          if (found < 0) {
            return Explanation.noMatch("not in top " + k);
          }
          return Explanation.match(scores[found], "within top " + k);
        }

        @Override
        public Scorer scorer(LeafReaderContext context) {

          return new Scorer(this) {
            final int lower = segmentStarts[context.ord];
            final int upper = segmentStarts[context.ord + 1];
            int upTo = -1;

            @Override
            public DocIdSetIterator iterator() {
              return new DocIdSetIterator() {
                @Override
                public int docID() {
                  return docIdNoShadow();
                }

                @Override
                public int nextDoc() {
                  if (upTo == -1) {
                    upTo = lower;
                  } else {
                    ++upTo;
                  }
                  return docIdNoShadow();
                }

                @Override
                public int advance(int target) throws IOException {
                  return slowAdvance(target);
                }

                @Override
                public long cost() {
                  return upper - lower;
                }
              };
            }

            @Override
            public float getMaxScore(int docid) {
              docid += context.docBase;
              float maxScore = 0;
              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
                maxScore = Math.max(maxScore, scores[idx]);
              }
              return maxScore;
            }

            @Override
            public float score() {
              return scores[upTo];
            }

            @Override
            public int advanceShallow(int docid) {
              int start = Math.max(upTo, lower);
              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
              if (docidIndex < 0) {
                docidIndex = -1 - docidIndex;
              }
              if (docidIndex >= upper) {
                return NO_MORE_DOCS;
              }
              return docs[docidIndex];
            }

            /**
             * move the implementation of docID() into a differently-named method so we can call it
             * from DocIdSetIterator.docID() even though this class is anonymous
             *
             * @return the current docid
             */
            private int docIdNoShadow() {
              if (upTo == -1) {
                return -1;
              }
              if (upTo >= upper) {
                return NO_MORE_DOCS;
              }
              return docs[upTo] - context.docBase;
            }

            @Override
            public int docID() {
              return docIdNoShadow();
            }
          };
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
          return true;
        }
      };
    }

    @Override
    public String toString(String field) {
      return "DocAndScore[" + k + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
      visitor.visitLeaf(this);
    }

    @Override
    public boolean equals(Object obj) {
      if (obj instanceof DocAndScoreQuery == false) {
        return false;
      }
      return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
    }

    @Override
    public int hashCode() {
      return Objects.hash(
          DocAndScoreQuery.class.hashCode(), Arrays.hashCode(docs), Arrays.hashCode(scores));
    }
  }
}
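To make the intended use of the new query concrete, here is a hedged, self-contained usage sketch assembled from the API added above and the tests that follow. The index path, field name, and example vectors are placeholders, not part of the commit.

// Usage sketch: index a few vectors, then ask for the 2 nearest to a target vector.
import java.nio.file.Paths;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnVectorQuery;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;

public class KnnVectorQueryExample {
  public static void main(String[] args) throws Exception {
    try (Directory dir = FSDirectory.open(Paths.get("/tmp/knn-demo"))) {
      // Index a handful of 2-d vectors in a KnnVectorField named "vector".
      try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
        for (float[] v : new float[][] {{0, 1}, {1, 2}, {0, 0}}) {
          Document doc = new Document();
          doc.add(new KnnVectorField("vector", v));
          writer.addDocument(doc);
        }
      }
      // Search for the 2 documents whose vectors are nearest to the target (0, 0).
      try (DirectoryReader reader = DirectoryReader.open(dir)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        TopDocs topDocs = searcher.search(new KnnVectorQuery("vector", new float[] {0, 0}, 2), 2);
        for (ScoreDoc hit : topDocs.scoreDocs) {
          System.out.println("doc=" + hit.doc + " score=" + hit.score);
        }
      }
    }
  }
}

Note that IndexSearcher.search rewrites the query first, so the per-segment vector search and the merge into a DocAndScoreQuery happen transparently inside the call.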
TestKnnVectorQuery.java (new file)
@@ -0,0 +1,324 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.search;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;

import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;

/** TestKnnVectorQuery tests KnnVectorQuery. */
public class TestKnnVectorQuery extends LuceneTestCase {

  public void testEquals() {
    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);

    assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));

    assertNotEquals(null, q1);

    assertNotEquals(q1, new TermQuery(new Term("f1", "x")));

    assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
  }

  public void testToString() {
    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
    assertEquals("<vector:f1[0.0,...][10]>", q1.toString("ignored"));
  }

  /**
   * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
   * match.
   */
  public void testEmptyIndex() throws IOException {
    try (Directory indexStore = getIndexStore("field");
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
      assertMatches(searcher, kvq, 0);
      Query q = searcher.rewrite(kvq);
      assertTrue(q instanceof MatchNoDocsQuery);
    }
  }

  /**
   * Tests that a KnnVectorQuery whose topK >= numDocs returns all the documents in score order
   */
  public void testFindAll() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
      assertMatches(searcher, kvq, reader.numDocs());
      TopDocs topDocs = searcher.search(kvq, 3);
      assertEquals(2, topDocs.scoreDocs[0].doc);
      assertEquals(0, topDocs.scoreDocs[1].doc);
      assertEquals(1, topDocs.scoreDocs[2].doc);
    }
  }

  /** Tests that searching with a vector of the wrong dimension fails with a clear message */
  public void testDimensionMismatch() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
      IllegalArgumentException e =
          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
      assertEquals("vector dimensions differ: 1!=2", e.getMessage());
    }
  }

  /** Tests that querying an unknown field, or a field with no vectors, matches nothing */
  public void testNonVectorField() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
      assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
    }
  }

  /** Test bad parameters */
  public void testIllegalArguments() throws IOException {
    expectThrows(
        IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
  }

  public void testDifferentReader() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
      Query dasq = query.rewrite(reader);
      IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
      expectThrows(
          IllegalStateException.class,
          () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
    }
  }

  public void testAdvanceShallow() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Query dasq = query.rewrite(reader);
        Scorer scorer =
            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
        // before advancing the iterator
        assertEquals(1, scorer.advanceShallow(0));
        assertEquals(1, scorer.advanceShallow(1));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));

        // after advancing the iterator
        scorer.iterator().advance(2);
        assertEquals(2, scorer.advanceShallow(0));
        assertEquals(2, scorer.advanceShallow(2));
        assertEquals(3, scorer.advanceShallow(3));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
      }
    }
  }

  public void testScore() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        assertEquals(1, reader.leaves().size());
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Query rewritten = query.rewrite(reader);
        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.leaves().get(0));

        // prior to advancing, docID is -1 and score() is undefined
        assertEquals(-1, scorer.docID());
        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());

        // test getMaxScore
        assertEquals(0, scorer.getMaxScore(-1), 0);
        assertEquals(0, scorer.getMaxScore(0), 0);
        // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
        assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
        assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);

        DocIdSetIterator it = scorer.iterator();
        assertEquals(3, it.cost());
        assertEquals(1, it.nextDoc());
        assertEquals(1 / 6f, scorer.score(), 0);
        assertEquals(3, it.advance(3));
        assertEquals(1 / 2f, scorer.score(), 0);
        assertEquals(NO_MORE_DOCS, it.advance(4));
        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
      }
    }
  }

  public void testExplain() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Explanation matched = searcher.explain(query, 2);
        assertTrue(matched.isMatch());
        assertEquals(1 / 2f, matched.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("within top 3", matched.getDescription());

        Explanation nomatch = searcher.explain(query, 4);
        assertFalse(nomatch.isMatch());
        assertEquals(0f, nomatch.getValue());
        assertEquals(0, nomatch.getDetails().length);
        assertEquals("not in top 3", nomatch.getDescription());
      }
    }
  }

  /** Test that when vectors are abnormally distributed among segments, we still find the top K */
  public void testSkewedIndex() throws IOException {
    /* We have to choose the numbers carefully here so that some segment has more than the expected
     * number of top K documents, but no more than K documents in total (otherwise we might occasionally
     * randomly fail to find one).
     */
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        int r = 0;
        for (int i = 0; i < 5; i++) {
          for (int j = 0; j < 5; j++) {
            Document doc = new Document();
            doc.add(new KnnVectorField("field", new float[] {r, r}));
            w.addDocument(doc);
            ++r;
          }
          w.flush();
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertEquals(0, results.scoreDocs[0].doc);
        assertEquals(7, results.scoreDocs[7].doc);

        // test some results in the middle of the sequence - also tests docid tiebreaking
        results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertEquals(10, results.scoreDocs[0].doc);
        assertEquals(6, results.scoreDocs[7].doc);
      }
    }
  }

  /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
  public void testRandom() throws IOException {
    int numDocs = atLeast(100);
    int dimension = atLeast(5);
    int numIters = atLeast(10);
    boolean everyDocHasAVector = random().nextBoolean();
    try (Directory d = newDirectory()) {
      RandomIndexWriter w = new RandomIndexWriter(random(), d);
      for (int i = 0; i < numDocs; i++) {
        Document doc = new Document();
        if (everyDocHasAVector || random().nextInt(10) != 2) {
          doc.add(new KnnVectorField("field", randomVector(dimension)));
        }
        w.addDocument(doc);
      }
      w.close();
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        for (int i = 0; i < numIters; i++) {
          int k = random().nextInt(100) + 1;
          KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k);
          int n = random().nextInt(100) + 1;
          TopDocs results = searcher.search(query, n);
          int expected = Math.min(Math.min(n, k), reader.numDocs());
          // we may get fewer results than requested if there are deletions, but this test doesn't
          // test that
          assert reader.hasDeletions() == false;
          assertEquals(expected, results.scoreDocs.length);
          assertTrue(results.totalHits.value >= results.scoreDocs.length);
          // verify the results are in descending score order
          float last = Float.MAX_VALUE;
          for (ScoreDoc scoreDoc : results.scoreDocs) {
            assertTrue(scoreDoc.score <= last);
            last = scoreDoc.score;
          }
        }
      }
    }
  }

  private Directory getIndexStore(String field, float[]... contents) throws IOException {
    Directory indexStore = newDirectory();
    RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
    for (int i = 0; i < contents.length; ++i) {
      Document doc = new Document();
      doc.add(new KnnVectorField(field, contents[i]));
      doc.add(new StringField("id", "id" + i, Field.Store.NO));
      writer.addDocument(doc);
    }
    writer.close();
    return indexStore;
  }

  private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
      throws IOException {
    ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
    assertEquals(expectedMatches, result.length);
  }
}
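One more illustrative sketch, this time of the bookkeeping that testSkewedIndex exercises: KnnVectorQuery#rewrite maps each per-segment hit to a global doc id by adding the segment's docBase, and findSegmentStarts then records where each segment's hits begin in the merged, doc-sorted arrays. The helper below reproduces that partitioning on hand-made values; the doc ids and doc bases are hypothetical and not taken from the tests, and the method is a standalone adaptation rather than the class's private method itself.

import java.util.Arrays;

// Illustrative only: mirrors the binary-search partitioning in KnnVectorQuery#findSegmentStarts.
public class SegmentStartsSketch {
  static int[] findSegmentStarts(int[] docBases, int[] docs) {
    int[] starts = new int[docBases.length + 1];
    starts[starts.length - 1] = docs.length;
    int resultIndex = 0;
    for (int i = 1; i < docBases.length; i++) {
      // the first merged (global) doc id belonging to segment i is >= that segment's docBase
      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, docBases[i]);
      if (resultIndex < 0) {
        resultIndex = -1 - resultIndex;
      }
      starts[i] = resultIndex;
    }
    return starts;
  }

  public static void main(String[] args) {
    int[] docBases = {0, 5, 10};   // three segments of five docs each (made-up)
    int[] docs = {1, 3, 6, 7, 12}; // merged top-k hits as global doc ids, ascending (made-up)
    // prints [0, 2, 4, 5]: docs[0..1] fall in segment 0, docs[2..3] in segment 1, docs[4] in segment 2
    System.out.println(Arrays.toString(findSegmentStarts(docBases, docs)));
  }
}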
TestVectorUtil.java
@@ -16,6 +16,8 @@
 */
package org.apache.lucene.util;

+import java.util.Random;
+
public class TestVectorUtil extends LuceneTestCase {

  public static final double DELTA = 1e-4;
@@ -81,7 +83,7 @@ public class TestVectorUtil extends LuceneTestCase {
    expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
  }

-  private float l2(float[] v) {
+  private static float l2(float[] v) {
    float l2 = 0;
    for (float x : v) {
      l2 += x * x;
@@ -89,7 +91,7 @@ public class TestVectorUtil extends LuceneTestCase {
    return l2;
  }

-  private float[] negative(float[] v) {
+  private static float[] negative(float[] v) {
    float[] u = new float[v.length];
    for (int i = 0; i < v.length; i++) {
      u[i] = -v[i];
@@ -97,10 +99,15 @@ public class TestVectorUtil extends LuceneTestCase {
    return u;
  }

-  private float[] randomVector() {
-    float[] v = new float[random().nextInt(100) + 1];
-    for (int i = 0; i < v.length; i++) {
-      v[i] = random().nextFloat();
+  private static float[] randomVector() {
+    return randomVector(random().nextInt(100) + 1);
+  }
+
+  public static float[] randomVector(int dim) {
+    float[] v = new float[dim];
+    Random random = random();
+    for (int i = 0; i < dim; i++) {
+      v[i] = random.nextFloat();
    }
    return v;
  }