LUCENE-10063: implement SimpleTextKnnvectorsReader.search

This commit is contained in:
Michael Sokolov 2021-08-31 13:55:13 -04:00 committed by GitHub
parent 6ade29c71a
commit 9c7f0d45ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 42 additions and 3 deletions

View File

@ -31,8 +31,13 @@ import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.RandomAccessVectorValues;
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.index.VectorValues;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.store.BufferedChecksumIndexInput;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.IOContext;
@ -140,7 +145,33 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
@Override
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
throw new UnsupportedOperationException();
VectorValues values = getVectorValues(field);
if (target.length != values.dimension()) {
throw new IllegalArgumentException(
"incorrect dimension for field "
+ field
+ "; expected "
+ values.dimension()
+ " but target has "
+ target.length);
}
FieldInfo info = readState.fieldInfos.fieldInfo(field);
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
HitQueue topK = new HitQueue(k, false);
int doc;
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
float[] vector = values.vectorValue();
float score = vectorSimilarity.compare(vector, target);
if (vectorSimilarity.reversed) {
score = 1 / (score + 1);
}
topK.insertWithOverflow(new ScoreDoc(doc, score));
}
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
topScoreDocs[i] = topK.pop();
}
return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
}
@Override

View File

@ -35,10 +35,8 @@ import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.LuceneTestCase;
import org.apache.lucene.util.LuceneTestCase.SuppressCodecs;
/** TestKnnVectorQuery tests KnnVectorQuery. */
@SuppressCodecs("SimpleText") // The codec must support kNN searches
public class TestKnnVectorQuery extends LuceneTestCase {
public void testEquals() {

View File

@ -31,6 +31,7 @@ import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.Directory;
import org.apache.lucene.store.FSDirectory;
import org.apache.lucene.util.Bits;
@ -856,6 +857,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
}
}
}
// assert that searchNearestVectors returns the expected number of documents, in
// descending score order
int k = random().nextInt(numDoc / 2);
TopDocs results =
ctx.reader().searchNearestVectors(fieldName, randomVector(dimension), k, liveDocs);
assertEquals(k, results.scoreDocs.length);
for (int i = 0; i < k - 1; i++) {
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
}
}
}
}