mirror of https://github.com/apache/lucene.git
LUCENE-10063: implement SimpleTextKnnVectorsReader.search
This commit is contained in:
parent
6ade29c71a
commit
9c7f0d45ee
|
@ -31,8 +31,13 @@ import org.apache.lucene.index.IndexFileNames;
|
|||
import org.apache.lucene.index.RandomAccessVectorValues;
|
||||
import org.apache.lucene.index.RandomAccessVectorValuesProducer;
|
||||
import org.apache.lucene.index.SegmentReadState;
|
||||
import org.apache.lucene.index.VectorSimilarityFunction;
|
||||
import org.apache.lucene.index.VectorValues;
|
||||
import org.apache.lucene.search.DocIdSetIterator;
|
||||
import org.apache.lucene.search.HitQueue;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.search.TotalHits;
|
||||
import org.apache.lucene.store.BufferedChecksumIndexInput;
|
||||
import org.apache.lucene.store.ChecksumIndexInput;
|
||||
import org.apache.lucene.store.IOContext;
|
||||
|
@ -140,7 +145,33 @@ public class SimpleTextKnnVectorsReader extends KnnVectorsReader {
|
|||
|
||||
@Override
|
||||
public TopDocs search(String field, float[] target, int k, Bits acceptDocs) throws IOException {
|
||||
throw new UnsupportedOperationException();
|
||||
VectorValues values = getVectorValues(field);
|
||||
if (target.length != values.dimension()) {
|
||||
throw new IllegalArgumentException(
|
||||
"incorrect dimension for field "
|
||||
+ field
|
||||
+ "; expected "
|
||||
+ values.dimension()
|
||||
+ " but target has "
|
||||
+ target.length);
|
||||
}
|
||||
FieldInfo info = readState.fieldInfos.fieldInfo(field);
|
||||
VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction();
|
||||
HitQueue topK = new HitQueue(k, false);
|
||||
int doc;
|
||||
while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
|
||||
float[] vector = values.vectorValue();
|
||||
float score = vectorSimilarity.compare(vector, target);
|
||||
if (vectorSimilarity.reversed) {
|
||||
score = 1 / (score + 1);
|
||||
}
|
||||
topK.insertWithOverflow(new ScoreDoc(doc, score));
|
||||
}
|
||||
ScoreDoc[] topScoreDocs = new ScoreDoc[topK.size()];
|
||||
for (int i = topScoreDocs.length - 1; i >= 0; i--) {
|
||||
topScoreDocs[i] = topK.pop();
|
||||
}
|
||||
return new TopDocs(new TotalHits(values.size(), TotalHits.Relation.EQUAL_TO), topScoreDocs);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
|
@ -35,10 +35,8 @@ import org.apache.lucene.index.RandomIndexWriter;
|
|||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.util.LuceneTestCase;
|
||||
import org.apache.lucene.util.LuceneTestCase.SuppressCodecs;
|
||||
|
||||
/** TestKnnVectorQuery tests KnnVectorQuery. */
|
||||
@SuppressCodecs("SimpleText") // The codec must support kNN searches
|
||||
public class TestKnnVectorQuery extends LuceneTestCase {
|
||||
|
||||
public void testEquals() {
|
||||
|
|
|
@ -31,6 +31,7 @@ import org.apache.lucene.document.NumericDocValuesField;
|
|||
import org.apache.lucene.document.StringField;
|
||||
import org.apache.lucene.search.Sort;
|
||||
import org.apache.lucene.search.SortField;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.apache.lucene.store.Directory;
|
||||
import org.apache.lucene.store.FSDirectory;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
@ -856,6 +857,15 @@ public abstract class BaseKnnVectorsFormatTestCase extends BaseIndexFileFormatTe
|
|||
}
|
||||
}
|
||||
}
|
||||
// assert that searchNearestVectors returns the expected number of documents, in
|
||||
// descending score order
|
||||
int k = random().nextInt(numDoc / 2);
|
||||
TopDocs results =
|
||||
ctx.reader().searchNearestVectors(fieldName, randomVector(dimension), k, liveDocs);
|
||||
assertEquals(k, results.scoreDocs.length);
|
||||
for (int i = 0; i < k - 1; i++) {
|
||||
assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue