LUCENE-9614: add KnnVectorQuery implementation

This commit is contained in:
Michael Sokolov 2021-08-13 12:15:40 -04:00 committed by GitHub
parent a9fb5a965d
commit 624560a3d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 645 additions and 7 deletions

View File

@ -259,7 +259,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
float score = results.topScore();
results.pop();
if (reversed) {
score = (float) Math.exp(-score / target.length);
score = 1 / (1 + score);
}
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
}

View File

@ -0,0 +1,307 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.Objects;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
/**
 * Uses {@link KnnVectorsReader#search} to perform nearest neighbour search.
 *
 * <p>The search runs eagerly when the query is rewritten: the top <code>k</code> hits are
 * gathered from every segment and merged, and the query rewrites into an internal
 * {@link DocAndScoreQuery} that simply replays the cached docids and scores.
 */
public class KnnVectorQuery extends Query {

  // Shared empty result returned for segments that have no vectors indexed for the field.
  private static final TopDocs NO_RESULTS =
      new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);

  private final String field;
  private final float[] target;
  private final int k;

  /**
   * Find the <code>k</code> nearest documents to the <code>target</code> vector according to the
   * vectors in the given field.
   *
   * @param field a field that has been indexed as a {@link KnnVectorField}.
   * @param target the target of the search
   * @param k the number of documents to find
   * @throws IllegalArgumentException if <code>k</code> is less than 1
   */
  public KnnVectorQuery(String field, float[] target, int k) {
    this.field = field;
    this.target = target;
    this.k = k;
    if (k < 1) {
      throw new IllegalArgumentException("k must be at least 1, got: " + k);
    }
  }

  @Override
  public Query rewrite(IndexReader reader) throws IOException {
    TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
    for (LeafReaderContext ctx : reader.leaves()) {
      // No leaf can usefully return more hits than the reader has live documents.
      perLeafResults[ctx.ord] = searchLeaf(ctx, Math.min(k, reader.numDocs()));
    }
    // Merge sort the results
    TopDocs topK = TopDocs.merge(k, perLeafResults);
    if (topK.scoreDocs.length == 0) {
      return new MatchNoDocsQuery();
    }
    return createRewrittenQuery(reader, topK);
  }

  /**
   * Runs the vector search on one segment and shifts the returned docids from segment-local to
   * global (index-wide) space so results from different leaves can be merged.
   */
  private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
    TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
    if (results == null) {
      // This segment has no vector data for the field.
      return NO_RESULTS;
    }
    if (ctx.docBase > 0) {
      for (ScoreDoc scoreDoc : results.scoreDocs) {
        scoreDoc.doc += ctx.docBase;
      }
    }
    return results;
  }

  /** Repackages the merged top-K hits into a {@link DocAndScoreQuery} over parallel arrays. */
  private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
    int len = topK.scoreDocs.length;
    // Sort by docid so per-segment ranges can be located with binary search in
    // findSegmentStarts, and so DocAndScoreQuery can iterate docs in order.
    Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
    int[] docs = new int[len];
    float[] scores = new float[len];
    for (int i = 0; i < len; i++) {
      docs[i] = topK.scoreDocs[i].doc;
      scores[i] = topK.scoreDocs[i].score;
    }
    int[] segmentStarts = findSegmentStarts(reader, docs);
    // The reader hash ties the rewritten query to this reader; see DocAndScoreQuery#createWeight.
    return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.hashCode());
  }

  /**
   * Computes, for each segment, the index in the doc-sorted <code>docs</code> array of the first
   * hit belonging to that segment. A trailing entry of <code>docs.length</code> is always stored
   * so that <code>starts[ord + 1]</code> is a valid exclusive upper bound for every segment.
   */
  private int[] findSegmentStarts(IndexReader reader, int[] docs) {
    int[] starts = new int[reader.leaves().size() + 1];
    starts[starts.length - 1] = docs.length;
    if (starts.length == 2) {
      // Single segment: every hit belongs to it; starts = {0, docs.length}.
      return starts;
    }
    int resultIndex = 0;
    for (int i = 1; i < starts.length - 1; i++) {
      int upper = reader.leaves().get(i).docBase;
      resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
      if (resultIndex < 0) {
        // Not found: binarySearch returns (-(insertion point) - 1); recover the insertion point.
        resultIndex = -1 - resultIndex;
      }
      starts[i] = resultIndex;
    }
    return starts;
  }

  @Override
  public String toString(String field) {
    // NOTE(review): assumes target is non-empty; target[0] throws on a zero-length vector.
    return "<vector:" + this.field + "[" + target[0] + ",...][" + k + "]>";
  }

  @Override
  public void visit(QueryVisitor visitor) {
    if (visitor.acceptField(field)) {
      visitor.visitLeaf(this);
    }
  }

  @Override
  public boolean equals(Object obj) {
    return obj instanceof KnnVectorQuery
        && ((KnnVectorQuery) obj).k == k
        && ((KnnVectorQuery) obj).field.equals(field)
        && Arrays.equals(((KnnVectorQuery) obj).target, target);
  }

  @Override
  public int hashCode() {
    return Objects.hash(field, k, Arrays.hashCode(target));
  }

  /** Caches the results of a KnnVector search: a list of docs and their scores */
  static class DocAndScoreQuery extends Query {

    private final int k;
    // Global docids of the cached hits, sorted ascending; scores is parallel to docs.
    private final int[] docs;
    private final float[] scores;
    private final int[] segmentStarts;
    private final int readerHash;

    /**
     * Constructor
     *
     * @param k the number of documents requested
     * @param docs the global docids of documents that match, in ascending order
     * @param scores the scores of the matching documents
     * @param segmentStarts the indexes in docs and scores corresponding to the first matching
     *     document in each segment. If a segment has no matching documents, it should be assigned
     *     the index of the next segment that does. There should be a final entry that is always
     *     docs.length (it is the exclusive upper bound for the last segment).
     * @param readerHash a hash code identifying the IndexReader used to create this query
     */
    DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, int readerHash) {
      this.k = k;
      this.docs = docs;
      this.scores = scores;
      this.segmentStarts = segmentStarts;
      this.readerHash = readerHash;
    }

    @Override
    public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
        throws IOException {
      // The cached docids are only meaningful for the reader they were computed against.
      if (searcher.getIndexReader().hashCode() != readerHash) {
        throw new IllegalStateException("This DocAndScore query was created by a different reader");
      }
      return new Weight(this) {
        @Override
        public Explanation explain(LeafReaderContext context, int doc) {
          // docs is sorted ascending (see createRewrittenQuery), so binary search is valid.
          int found = Arrays.binarySearch(docs, doc);
          if (found < 0) {
            return Explanation.noMatch("not in top " + k);
          }
          return Explanation.match(scores[found], "within top " + k);
        }

        @Override
        public Scorer scorer(LeafReaderContext context) {
          return new Scorer(this) {
            // [lower, upper) is this segment's slice of the docs/scores arrays.
            final int lower = segmentStarts[context.ord];
            final int upper = segmentStarts[context.ord + 1];
            // Cursor into docs/scores; -1 means iteration has not started yet.
            int upTo = -1;

            @Override
            public DocIdSetIterator iterator() {
              return new DocIdSetIterator() {
                @Override
                public int docID() {
                  return docIdNoShadow();
                }

                @Override
                public int nextDoc() {
                  if (upTo == -1) {
                    upTo = lower;
                  } else {
                    ++upTo;
                  }
                  return docIdNoShadow();
                }

                @Override
                public int advance(int target) throws IOException {
                  // Cached result sets are tiny (at most k entries), so linear advance is fine.
                  return slowAdvance(target);
                }

                @Override
                public long cost() {
                  return upper - lower;
                }
              };
            }

            @Override
            public float getMaxScore(int docid) {
              // The stored docids are global, so shift the argument by docBase before comparing.
              docid += context.docBase;
              float maxScore = 0;
              for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
                maxScore = Math.max(maxScore, scores[idx]);
              }
              return maxScore;
            }

            @Override
            public float score() {
              // Precondition: the iterator is positioned on a valid doc (lower <= upTo < upper).
              return scores[upTo];
            }

            @Override
            public int advanceShallow(int docid) {
              int start = Math.max(upTo, lower);
              int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
              if (docidIndex < 0) {
                docidIndex = -1 - docidIndex;
              }
              if (docidIndex >= upper) {
                return NO_MORE_DOCS;
              }
              // NOTE(review): this returns a global docid while the input is segment-local; the
              // two coincide only when context.docBase == 0 — confirm intended for multi-segment.
              return docs[docidIndex];
            }

            /**
             * move the implementation of docID() into a differently-named method so we can call it
             * from DocIDSetIterator.docID() even though this class is anonymous
             *
             * @return the current docid
             */
            private int docIdNoShadow() {
              if (upTo == -1) {
                return -1;
              }
              if (upTo >= upper) {
                return NO_MORE_DOCS;
              }
              // Stored docids are global; convert back to segment-local space for the caller.
              return docs[upTo] - context.docBase;
            }

            @Override
            public int docID() {
              return docIdNoShadow();
            }
          };
        }

        @Override
        public boolean isCacheable(LeafReaderContext ctx) {
          // Results are fixed at rewrite time; nothing here depends on per-segment state.
          return true;
        }
      };
    }

    @Override
    public String toString(String field) {
      return "DocAndScore[" + k + "]";
    }

    @Override
    public void visit(QueryVisitor visitor) {
      visitor.visitLeaf(this);
    }

    @Override
    public boolean equals(Object obj) {
      if (obj instanceof DocAndScoreQuery == false) {
        return false;
      }
      return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
          && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
    }

    @Override
    public int hashCode() {
      return Objects.hash(
          DocAndScoreQuery.class.hashCode(), Arrays.hashCode(docs), Arrays.hashCode(scores));
    }
  }
}

View File

@ -0,0 +1,324 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import static org.apache.lucene.util.TestVectorUtil.randomVector;
import java.io.IOException;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
/** TestKnnVectorQuery tests KnnVectorQuery. */
public class TestKnnVectorQuery extends LuceneTestCase {

  public void testEquals() {
    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
    assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));
    // Query first, null second: with (null, q1) the comparison short-circuits on the null
    // argument and KnnVectorQuery.equals(null) is never actually invoked.
    assertNotEquals(q1, null);
    assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
    assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
    assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
  }

  public void testToString() {
    KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
    assertEquals("<vector:f1[0.0,...][10]>", q1.toString("ignored"));
  }

  /**
   * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
   * match.
   */
  public void testEmptyIndex() throws IOException {
    try (Directory indexStore = getIndexStore("field");
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
      assertMatches(searcher, kvq, 0);
      Query q = searcher.rewrite(kvq);
      assertTrue(q instanceof MatchNoDocsQuery);
    }
  }

  /**
   * Tests that a KnnVectorQuery whose topK &gt;= numDocs returns all the documents in score order
   */
  public void testFindAll() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
      assertMatches(searcher, kvq, reader.numDocs());
      // doc 2 holds the exact target vector; docs 0 and 1 follow by increasing distance
      TopDocs topDocs = searcher.search(kvq, 3);
      assertEquals(2, topDocs.scoreDocs[0].doc);
      assertEquals(0, topDocs.scoreDocs[1].doc);
      assertEquals(1, topDocs.scoreDocs[2].doc);
    }
  }

  /** Searching with a target whose dimension differs from the indexed vectors must fail fast. */
  public void testDimensionMismatch() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
      IllegalArgumentException e =
          expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
      assertEquals("vector dimensions differ: 1!=2", e.getMessage());
    }
  }

  /** A query over a missing field, or a field indexed without vectors, matches no documents. */
  public void testNonVectorField() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      IndexSearcher searcher = newSearcher(reader);
      assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
      assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
    }
  }

  /** Test bad parameters */
  public void testIllegalArguments() throws IOException {
    expectThrows(
        IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
  }

  /** A rewritten (DocAndScore) query may only be used with the reader that produced it. */
  public void testDifferentReader() throws IOException {
    try (Directory indexStore =
            getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
        IndexReader reader = DirectoryReader.open(indexStore)) {
      KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
      Query dasq = query.rewrite(reader);
      IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
      expectThrows(
          IllegalStateException.class,
          () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
    }
  }

  public void testAdvanceShallow() throws IOException {
    try (Directory d = newDirectory()) {
      // Index 5 documents with vectors (j, j) for j = 0..4 in a single segment.
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Query dasq = query.rewrite(reader);
        Scorer scorer =
            dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
        // before advancing the iterator
        assertEquals(1, scorer.advanceShallow(0));
        assertEquals(1, scorer.advanceShallow(1));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
        // after advancing the iterator
        scorer.iterator().advance(2);
        assertEquals(2, scorer.advanceShallow(0));
        assertEquals(2, scorer.advanceShallow(2));
        assertEquals(3, scorer.advanceShallow(3));
        assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
      }
    }
  }

  public void testScore() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        assertEquals(1, reader.leaves().size());
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Query rewritten = query.rewrite(reader);
        Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
        Scorer scorer = weight.scorer(reader.leaves().get(0));
        // prior to advancing, score is 0
        assertEquals(-1, scorer.docID());
        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
        // test getMaxScore
        assertEquals(0, scorer.getMaxScore(-1), 0);
        assertEquals(0, scorer.getMaxScore(0), 0);
        // This is 1 / ((l2distance((2,3), (2, 2)) = 1) + 1) = 0.5
        assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
        assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
        DocIdSetIterator it = scorer.iterator();
        assertEquals(3, it.cost());
        assertEquals(1, it.nextDoc());
        assertEquals(1 / 6f, scorer.score(), 0);
        assertEquals(3, it.advance(3));
        assertEquals(1 / 2f, scorer.score(), 0);
        assertEquals(NO_MORE_DOCS, it.advance(4));
        // score() is undefined once the iterator is exhausted
        expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
      }
    }
  }

  public void testExplain() throws IOException {
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        for (int j = 0; j < 5; j++) {
          Document doc = new Document();
          doc.add(new KnnVectorField("field", new float[] {j, j}));
          w.addDocument(doc);
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = new IndexSearcher(reader);
        KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
        Explanation matched = searcher.explain(query, 2);
        assertTrue(matched.isMatch());
        assertEquals(1 / 2f, matched.getValue());
        assertEquals(0, matched.getDetails().length);
        assertEquals("within top 3", matched.getDescription());
        Explanation nomatch = searcher.explain(query, 4);
        assertFalse(nomatch.isMatch());
        assertEquals(0f, nomatch.getValue());
        // was re-checking `matched` here by copy-paste mistake; verify the no-match explanation
        assertEquals(0, nomatch.getDetails().length);
        assertEquals("not in top 3", nomatch.getDescription());
      }
    }
  }

  /** Test that when vectors are abnormally distributed among segments, we still find the top K */
  public void testSkewedIndex() throws IOException {
    /* We have to choose the numbers carefully here so that some segment has more than the expected
     * number of top K documents, but no more than K documents in total (otherwise we might occasionally
     * randomly fail to find one).
     */
    try (Directory d = newDirectory()) {
      try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
        int r = 0;
        // 5 segments of 5 docs each (flush after every batch), vectors (r, r) for r = 0..24
        for (int i = 0; i < 5; i++) {
          for (int j = 0; j < 5; j++) {
            Document doc = new Document();
            doc.add(new KnnVectorField("field", new float[] {r, r}));
            w.addDocument(doc);
            ++r;
          }
          w.flush();
        }
      }
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertEquals(0, results.scoreDocs[0].doc);
        assertEquals(7, results.scoreDocs[7].doc);
        // test some results in the middle of the sequence - also tests docid tiebreaking
        results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10);
        assertEquals(8, results.scoreDocs.length);
        assertEquals(10, results.scoreDocs[0].doc);
        assertEquals(6, results.scoreDocs[7].doc);
      }
    }
  }

  /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
  public void testRandom() throws IOException {
    int numDocs = atLeast(100);
    int dimension = atLeast(5);
    int numIters = atLeast(10);
    boolean everyDocHasAVector = random().nextBoolean();
    try (Directory d = newDirectory()) {
      RandomIndexWriter w = new RandomIndexWriter(random(), d);
      for (int i = 0; i < numDocs; i++) {
        Document doc = new Document();
        // occasionally leave a doc without a vector unless everyDocHasAVector is set
        if (everyDocHasAVector || random().nextInt(10) != 2) {
          doc.add(new KnnVectorField("field", randomVector(dimension)));
        }
        w.addDocument(doc);
      }
      w.close();
      try (IndexReader reader = DirectoryReader.open(d)) {
        IndexSearcher searcher = newSearcher(reader);
        for (int i = 0; i < numIters; i++) {
          int k = random().nextInt(100) + 1;
          KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k);
          int n = random().nextInt(100) + 1;
          TopDocs results = searcher.search(query, n);
          int expected = Math.min(Math.min(n, k), reader.numDocs());
          // we may get fewer results than requested if there are deletions, but this test doesn't
          // test that
          assert reader.hasDeletions() == false;
          assertEquals(expected, results.scoreDocs.length);
          assertTrue(results.totalHits.value >= results.scoreDocs.length);
          // verify the results are in descending score order
          float last = Float.MAX_VALUE;
          for (ScoreDoc scoreDoc : results.scoreDocs) {
            assertTrue(scoreDoc.score <= last);
            last = scoreDoc.score;
          }
        }
      }
    }
  }

  /** Creates a new directory holding one document per supplied vector, plus a StringField id. */
  private Directory getIndexStore(String field, float[]... contents) throws IOException {
    Directory indexStore = newDirectory();
    RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
    for (int i = 0; i < contents.length; ++i) {
      Document doc = new Document();
      doc.add(new KnnVectorField(field, contents[i]));
      doc.add(new StringField("id", "id" + i, Field.Store.NO));
      writer.addDocument(doc);
    }
    writer.close();
    return indexStore;
  }

  /** Asserts that the query matches exactly the expected number of documents. */
  private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
      throws IOException {
    ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
    assertEquals(expectedMatches, result.length);
  }
}

View File

@ -16,6 +16,8 @@
*/
package org.apache.lucene.util;
import java.util.Random;
public class TestVectorUtil extends LuceneTestCase {
public static final double DELTA = 1e-4;
@ -81,7 +83,7 @@ public class TestVectorUtil extends LuceneTestCase {
expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
}
private float l2(float[] v) {
private static float l2(float[] v) {
float l2 = 0;
for (float x : v) {
l2 += x * x;
@ -89,7 +91,7 @@ public class TestVectorUtil extends LuceneTestCase {
return l2;
}
private float[] negative(float[] v) {
private static float[] negative(float[] v) {
float[] u = new float[v.length];
for (int i = 0; i < v.length; i++) {
u[i] = -v[i];
@ -97,10 +99,15 @@ public class TestVectorUtil extends LuceneTestCase {
return u;
}
private float[] randomVector() {
float[] v = new float[random().nextInt(100) + 1];
for (int i = 0; i < v.length; i++) {
v[i] = random().nextFloat();
private static float[] randomVector() {
return randomVector(random().nextInt(100) + 1);
}
public static float[] randomVector(int dim) {
float[] v = new float[dim];
Random random = random();
for (int i = 0; i < dim; i++) {
v[i] = random.nextFloat();
}
return v;
}