diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
index 8244026060c..6a69ab9f4e2 100644
--- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
+++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/Lucene90HnswVectorsReader.java
@@ -259,7 +259,7 @@ public final class Lucene90HnswVectorsReader extends KnnVectorsReader {
float score = results.topScore();
results.pop();
if (reversed) {
- score = (float) Math.exp(-score / target.length);
+ score = 1 / (1 + score);
}
scoreDocs[scoreDocs.length - ++i] = new ScoreDoc(fieldEntry.ordToDoc[node], score);
}
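Note on the score change above: for a reversed (distance-based) similarity, the raw result is a distance where smaller is better. The old mapping, Math.exp(-score / target.length), depended on the vector dimension; the new mapping, 1 / (1 + score), is dimension-independent and still maps [0, Inf) monotonically onto (0, 1]. A minimal standalone sanity check, not part of the patch, assuming the reversed similarity is the squared Euclidean distance (which is what the expectations in TestKnnVectorQuery below suggest):

    // Sketch of the conversion applied when reversed is true.
    public class ReversedScoreCheck {
      static float reversedScore(float distance) {
        return 1 / (1 + distance); // same formula as in Lucene90HnswVectorsReader
      }

      public static void main(String[] args) {
        System.out.println(reversedScore(0f)); // identical vectors       -> 1.0
        System.out.println(reversedScore(1f)); // e.g. (2, 3) vs (2, 2)   -> 0.5
        System.out.println(reversedScore(5f)); // e.g. (2, 3) vs (1, 1)   -> 0.16666667 (1/6)
      }
    }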
diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
new file mode 100644
index 00000000000..5dccb8042e4
--- /dev/null
+++ b/lucene/core/src/java/org/apache/lucene/search/KnnVectorQuery.java
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.Objects;
+import org.apache.lucene.codecs.KnnVectorsReader;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+
+/** Uses {@link KnnVectorsReader#search} to perform nearest neighbor search. */
+public class KnnVectorQuery extends Query {
+
+ private static final TopDocs NO_RESULTS =
+ new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]);
+
+ private final String field;
+ private final float[] target;
+ private final int k;
+
+ /**
+ * Find the k nearest documents to the target vector according to the vectors in the given field.
+ *
+ * @param field a field that has been indexed as a {@link KnnVectorField}.
+ * @param target the target of the search
+ * @param k the number of documents to find
+ * @throws IllegalArgumentException if k is less than 1
+ */
+ public KnnVectorQuery(String field, float[] target, int k) {
+ this.field = field;
+ this.target = target;
+ this.k = k;
+ if (k < 1) {
+ throw new IllegalArgumentException("k must be at least 1, got: " + k);
+ }
+ }
+
+ @Override
+ public Query rewrite(IndexReader reader) throws IOException {
+ TopDocs[] perLeafResults = new TopDocs[reader.leaves().size()];
+ for (LeafReaderContext ctx : reader.leaves()) {
+ perLeafResults[ctx.ord] = searchLeaf(ctx, Math.min(k, reader.numDocs()));
+ }
+ // Merge sort the results
+ TopDocs topK = TopDocs.merge(k, perLeafResults);
+ if (topK.scoreDocs.length == 0) {
+ return new MatchNoDocsQuery();
+ }
+ return createRewrittenQuery(reader, topK);
+ }
+
+ private TopDocs searchLeaf(LeafReaderContext ctx, int kPerLeaf) throws IOException {
+ TopDocs results = ctx.reader().searchNearestVectors(field, target, kPerLeaf);
+ if (results == null) {
+ return NO_RESULTS;
+ }
+ if (ctx.docBase > 0) {
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ scoreDoc.doc += ctx.docBase;
+ }
+ }
+ return results;
+ }
+
+ private Query createRewrittenQuery(IndexReader reader, TopDocs topK) {
+ int len = topK.scoreDocs.length;
+ Arrays.sort(topK.scoreDocs, Comparator.comparingInt(a -> a.doc));
+ int[] docs = new int[len];
+ float[] scores = new float[len];
+ for (int i = 0; i < len; i++) {
+ docs[i] = topK.scoreDocs[i].doc;
+ scores[i] = topK.scoreDocs[i].score;
+ }
+ int[] segmentStarts = findSegmentStarts(reader, docs);
+ return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.hashCode());
+ }
+
+ private int[] findSegmentStarts(IndexReader reader, int[] docs) {
+ int[] starts = new int[reader.leaves().size() + 1];
+ starts[starts.length - 1] = docs.length;
+ if (starts.length == 2) {
+ return starts;
+ }
+ int resultIndex = 0;
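+ // For each leaf after the first, find the offset of the first result doc at or after that
+ // leaf's docBase; a negative binarySearch result encodes the insertion point as -(point) - 1.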
+ for (int i = 1; i < starts.length - 1; i++) {
+ int upper = reader.leaves().get(i).docBase;
+ resultIndex = Arrays.binarySearch(docs, resultIndex, docs.length, upper);
+ if (resultIndex < 0) {
+ resultIndex = -1 - resultIndex;
+ }
+ starts[i] = resultIndex;
+ }
+ return starts;
+ }
+
+ @Override
+ public String toString(String field) {
+ return "";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ if (visitor.acceptField(field)) {
+ visitor.visitLeaf(this);
+ }
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj instanceof KnnVectorQuery
+ && ((KnnVectorQuery) obj).k == k
+ && ((KnnVectorQuery) obj).field.equals(field)
+ && Arrays.equals(((KnnVectorQuery) obj).target, target);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(field, k, Arrays.hashCode(target));
+ }
+
+ /** Caches the results of a KnnVector search: a list of docs and their scores */
+ static class DocAndScoreQuery extends Query {
+
+ private final int k;
+ private final int[] docs;
+ private final float[] scores;
+ private final int[] segmentStarts;
+ private final int readerHash;
+
+ /**
+ * Constructor
+ *
+ * @param k the number of documents requested
+ * @param docs the global docids of documents that match, in ascending order
+ * @param scores the scores of the matching documents
+ * @param segmentStarts the indexes in docs and scores corresponding to the first matching
+ * document in each segment. If a segment has no matching documents, it should be assigned
+ * the index of the next segment that does. There should be a final entry that is always
+ * docs.length.
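+ * For example, with two segments where the second segment's docBase is 40 and
+ * docs = [5, 7, 42], segmentStarts would be [0, 2, 3].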
+ * @param readerHash a hash code identifying the IndexReader used to create this query
+ */
+ DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, int readerHash) {
+ this.k = k;
+ this.docs = docs;
+ this.scores = scores;
+ this.segmentStarts = segmentStarts;
+ this.readerHash = readerHash;
+ }
+
+ @Override
+ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
+ throws IOException {
+ if (searcher.getIndexReader().hashCode() != readerHash) {
+ throw new IllegalStateException("This DocAndScore query was created by a different reader");
+ }
+ return new Weight(this) {
+ @Override
+ public Explanation explain(LeafReaderContext context, int doc) {
+ int found = Arrays.binarySearch(docs, doc);
+ if (found < 0) {
+ return Explanation.noMatch("not in top " + k);
+ }
+ return Explanation.match(scores[found], "within top " + k);
+ }
+
+ @Override
+ public Scorer scorer(LeafReaderContext context) {
+
+ return new Scorer(this) {
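+ // Each per-segment scorer replays the slice of the precomputed global docs/scores that
+ // belongs to this segment, bounded by segmentStarts[ord] and segmentStarts[ord + 1].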
+ final int lower = segmentStarts[context.ord];
+ final int upper = segmentStarts[context.ord + 1];
+ int upTo = -1;
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return new DocIdSetIterator() {
+ @Override
+ public int docID() {
+ return docIdNoShadow();
+ }
+
+ @Override
+ public int nextDoc() {
+ if (upTo == -1) {
+ upTo = lower;
+ } else {
+ ++upTo;
+ }
+ return docIdNoShadow();
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ return slowAdvance(target);
+ }
+
+ @Override
+ public long cost() {
+ return upper - lower;
+ }
+ };
+ }
+
+ @Override
+ public float getMaxScore(int docid) {
+ docid += context.docBase;
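+ // The docid argument is segment-local; convert it to a global docid before comparing it
+ // against the cached doc list, which holds global docids in ascending order.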
+ float maxScore = 0;
+ for (int idx = Math.max(0, upTo); idx < upper && docs[idx] <= docid; idx++) {
+ maxScore = Math.max(maxScore, scores[idx]);
+ }
+ return maxScore;
+ }
+
+ @Override
+ public float score() {
+ return scores[upTo];
+ }
+
+ @Override
+ public int advanceShallow(int docid) {
+ int start = Math.max(upTo, lower);
+ int docidIndex = Arrays.binarySearch(docs, start, upper, docid + context.docBase);
+ if (docidIndex < 0) {
+ docidIndex = -1 - docidIndex;
+ }
+ if (docidIndex >= upper) {
+ return NO_MORE_DOCS;
+ }
+ return docs[docidIndex];
+ }
+
+ /**
+ * Move the implementation of docID() into a differently-named method so that we can call it
+ * from DocIdSetIterator.docID() even though this class is anonymous
+ *
+ * @return the current docid
+ */
+ private int docIdNoShadow() {
+ if (upTo == -1) {
+ return -1;
+ }
+ if (upTo >= upper) {
+ return NO_MORE_DOCS;
+ }
+ return docs[upTo] - context.docBase;
+ }
+
+ @Override
+ public int docID() {
+ return docIdNoShadow();
+ }
+ };
+ }
+
+ @Override
+ public boolean isCacheable(LeafReaderContext ctx) {
+ return true;
+ }
+ };
+ }
+
+ @Override
+ public String toString(String field) {
+ return "DocAndScore[" + k + "]";
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ visitor.visitLeaf(this);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj instanceof DocAndScoreQuery == false) {
+ return false;
+ }
+ return Arrays.equals(docs, ((DocAndScoreQuery) obj).docs)
+ && Arrays.equals(scores, ((DocAndScoreQuery) obj).scores);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(
+ DocAndScoreQuery.class.hashCode(), Arrays.hashCode(docs), Arrays.hashCode(scores));
+ }
+ }
+}
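For reference, a usage sketch of the new query, not part of the patch and modeled on the tests below; the field name, directory, and example vectors are illustrative, and the single-argument KnnVectorField constructor is assumed to default to Euclidean similarity (which is what the score expectations in TestKnnVectorQuery rely on):

    import org.apache.lucene.document.Document;
    import org.apache.lucene.document.KnnVectorField;
    import org.apache.lucene.index.DirectoryReader;
    import org.apache.lucene.index.IndexWriter;
    import org.apache.lucene.index.IndexWriterConfig;
    import org.apache.lucene.search.IndexSearcher;
    import org.apache.lucene.search.KnnVectorQuery;
    import org.apache.lucene.search.ScoreDoc;
    import org.apache.lucene.search.TopDocs;
    import org.apache.lucene.store.ByteBuffersDirectory;
    import org.apache.lucene.store.Directory;

    public class KnnVectorQueryExample {
      public static void main(String[] args) throws Exception {
        try (Directory dir = new ByteBuffersDirectory()) {
          try (IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig())) {
            for (int i = 0; i < 5; i++) {
              Document doc = new Document();
              doc.add(new KnnVectorField("vector", new float[] {i, i}));
              writer.addDocument(doc);
            }
          }
          try (DirectoryReader reader = DirectoryReader.open(dir)) {
            IndexSearcher searcher = new IndexSearcher(reader);
            // Find the 3 nearest documents to (2, 3); with Euclidean similarity the
            // score of each hit is 1 / (1 + squared distance to the target).
            TopDocs top = searcher.search(new KnnVectorQuery("vector", new float[] {2, 3}, 3), 3);
            for (ScoreDoc sd : top.scoreDocs) {
              System.out.println("doc=" + sd.doc + " score=" + sd.score);
            }
          }
        }
      }
    }

The query does its nearest-neighbor work at rewrite time, collecting the global top k across segments and rewriting to the internal DocAndScoreQuery, so it must be rewritten and executed against the same IndexReader.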
diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
new file mode 100644
index 00000000000..862f8f7c114
--- /dev/null
+++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnVectorQuery.java
@@ -0,0 +1,324 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search;
+
+import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
+import static org.apache.lucene.util.TestVectorUtil.randomVector;
+
+import java.io.IOException;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.KnnVectorField;
+import org.apache.lucene.document.StringField;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriter;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+
+/** Tests for {@link KnnVectorQuery}. */
+public class TestKnnVectorQuery extends LuceneTestCase {
+
+ public void testEquals() {
+ KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
+
+ assertEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 10));
+
+ assertNotEquals(null, q1);
+
+ assertNotEquals(q1, new TermQuery(new Term("f1", "x")));
+
+ assertNotEquals(q1, new KnnVectorQuery("f2", new float[] {0, 1}, 10));
+ assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {1, 1}, 10));
+ assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0, 1}, 2));
+ assertNotEquals(q1, new KnnVectorQuery("f1", new float[] {0}, 10));
+ }
+
+ public void testToString() {
+ KnnVectorQuery q1 = new KnnVectorQuery("f1", new float[] {0, 1}, 10);
+ assertEquals("", q1.toString("ignored"));
+ }
+
+ /**
+ * Tests if a KnnVectorQuery is rewritten to a MatchNoDocsQuery when there are no documents to
+ * match.
+ */
+ public void testEmptyIndex() throws IOException {
+ try (Directory indexStore = getIndexStore("field");
+ IndexReader reader = DirectoryReader.open(indexStore)) {
+ IndexSearcher searcher = newSearcher(reader);
+ KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {1, 2}, 10);
+ assertMatches(searcher, kvq, 0);
+ Query q = searcher.rewrite(kvq);
+ assertTrue(q instanceof MatchNoDocsQuery);
+ }
+ }
+
+ /**
+ * Tests that a KnnVectorQuery whose topK >= numDocs returns all the documents in score order
+ */
+ public void testFindAll() throws IOException {
+ try (Directory indexStore =
+ getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+ IndexReader reader = DirectoryReader.open(indexStore)) {
+ IndexSearcher searcher = newSearcher(reader);
+ KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0, 0}, 10);
+ assertMatches(searcher, kvq, reader.numDocs());
+ TopDocs topDocs = searcher.search(kvq, 3);
+ assertEquals(2, topDocs.scoreDocs[0].doc);
+ assertEquals(0, topDocs.scoreDocs[1].doc);
+ assertEquals(1, topDocs.scoreDocs[2].doc);
+ }
+ }
+
+ /** Tests that a query vector whose dimension differs from the indexed vectors' dimension is rejected. */
+ public void testDimensionMismatch() throws IOException {
+ try (Directory indexStore =
+ getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+ IndexReader reader = DirectoryReader.open(indexStore)) {
+ IndexSearcher searcher = newSearcher(reader);
+ KnnVectorQuery kvq = new KnnVectorQuery("field", new float[] {0}, 10);
+ IllegalArgumentException e =
+ expectThrows(IllegalArgumentException.class, () -> searcher.search(kvq, 10));
+ assertEquals("vector dimensions differ: 1!=2", e.getMessage());
+ }
+ }
+
+ /** Tests that querying a non-vector field, or a field with no indexed vectors, matches no documents. */
+ public void testNonVectorField() throws IOException {
+ try (Directory indexStore =
+ getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+ IndexReader reader = DirectoryReader.open(indexStore)) {
+ IndexSearcher searcher = newSearcher(reader);
+ assertMatches(searcher, new KnnVectorQuery("xyzzy", new float[] {0}, 10), 0);
+ assertMatches(searcher, new KnnVectorQuery("id", new float[] {0}, 10), 0);
+ }
+ }
+
+ /** Test bad parameters */
+ public void testIllegalArguments() throws IOException {
+ expectThrows(
+ IllegalArgumentException.class, () -> new KnnVectorQuery("xx", new float[] {1}, 0));
+ }
+
+ public void testDifferentReader() throws IOException {
+ try (Directory indexStore =
+ getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
+ IndexReader reader = DirectoryReader.open(indexStore)) {
+ KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ Query dasq = query.rewrite(reader);
+ IndexSearcher leafSearcher = newSearcher(reader.leaves().get(0).reader());
+ expectThrows(
+ IllegalStateException.class,
+ () -> dasq.createWeight(leafSearcher, ScoreMode.COMPLETE, 1));
+ }
+ }
+
+ public void testAdvanceShallow() throws IOException {
+ try (Directory d = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+ for (int j = 0; j < 5; j++) {
+ Document doc = new Document();
+ doc.add(new KnnVectorField("field", new float[] {j, j}));
+ w.addDocument(doc);
+ }
+ }
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ Query dasq = query.rewrite(reader);
+ Scorer scorer =
+ dasq.createWeight(searcher, ScoreMode.COMPLETE, 1).scorer(reader.leaves().get(0));
+ // before advancing the iterator
+ assertEquals(1, scorer.advanceShallow(0));
+ assertEquals(1, scorer.advanceShallow(1));
+ assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
+
+ // after advancing the iterator
+ scorer.iterator().advance(2);
+ assertEquals(2, scorer.advanceShallow(0));
+ assertEquals(2, scorer.advanceShallow(2));
+ assertEquals(3, scorer.advanceShallow(3));
+ assertEquals(NO_MORE_DOCS, scorer.advanceShallow(10));
+ }
+ }
+ }
+
+ public void testScore() throws IOException {
+ try (Directory d = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+ for (int j = 0; j < 5; j++) {
+ Document doc = new Document();
+ doc.add(new KnnVectorField("field", new float[] {j, j}));
+ w.addDocument(doc);
+ }
+ }
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ assertEquals(1, reader.leaves().size());
+ IndexSearcher searcher = new IndexSearcher(reader);
+ KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ Query rewritten = query.rewrite(reader);
+ Weight weight = searcher.createWeight(rewritten, ScoreMode.COMPLETE, 1);
+ Scorer scorer = weight.scorer(reader.leaves().get(0));
+
+ // prior to advancing, docID is -1 and score() is undefined
+ assertEquals(-1, scorer.docID());
+ expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+
+ // test getMaxScore
+ assertEquals(0, scorer.getMaxScore(-1), 0);
+ assertEquals(0, scorer.getMaxScore(0), 0);
+ // max score is 1 / (1 + squared l2 distance from (2, 3) to (2, 2)) = 1 / (1 + 1) = 0.5
+ assertEquals(1 / 2f, scorer.getMaxScore(2), 0);
+ assertEquals(1 / 2f, scorer.getMaxScore(Integer.MAX_VALUE), 0);
+
+ DocIdSetIterator it = scorer.iterator();
+ assertEquals(3, it.cost());
+ assertEquals(1, it.nextDoc());
+ assertEquals(1 / 6f, scorer.score(), 0);
+ assertEquals(3, it.advance(3));
+ assertEquals(1 / 2f, scorer.score(), 0);
+ assertEquals(NO_MORE_DOCS, it.advance(4));
+ expectThrows(ArrayIndexOutOfBoundsException.class, () -> scorer.score());
+ }
+ }
+ }
+
+ public void testExplain() throws IOException {
+ try (Directory d = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+ for (int j = 0; j < 5; j++) {
+ Document doc = new Document();
+ doc.add(new KnnVectorField("field", new float[] {j, j}));
+ w.addDocument(doc);
+ }
+ }
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = new IndexSearcher(reader);
+ KnnVectorQuery query = new KnnVectorQuery("field", new float[] {2, 3}, 3);
+ Explanation matched = searcher.explain(query, 2);
+ assertTrue(matched.isMatch());
+ assertEquals(1 / 2f, matched.getValue());
+ assertEquals(0, matched.getDetails().length);
+ assertEquals("within top 3", matched.getDescription());
+
+ Explanation nomatch = searcher.explain(query, 4);
+ assertFalse(nomatch.isMatch());
+ assertEquals(0f, nomatch.getValue());
+ assertEquals(0, matched.getDetails().length);
+ assertEquals("not in top 3", nomatch.getDescription());
+ }
+ }
+ }
+
+ /** Test that when vectors are abnormally distributed among segments, we still find the top K */
+ public void testSkewedIndex() throws IOException {
+ /* We have to choose the numbers carefully here so that some segment holds more than its
+ * proportional share of the top K documents, but no segment holds more than K documents in
+ * total (otherwise we might occasionally, randomly fail to find one).
+ */
+ try (Directory d = newDirectory()) {
+ try (IndexWriter w = new IndexWriter(d, new IndexWriterConfig())) {
+ int r = 0;
+ for (int i = 0; i < 5; i++) {
+ for (int j = 0; j < 5; j++) {
+ Document doc = new Document();
+ doc.add(new KnnVectorField("field", new float[] {r, r}));
+ w.addDocument(doc);
+ ++r;
+ }
+ w.flush();
+ }
+ }
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = newSearcher(reader);
+ TopDocs results = searcher.search(new KnnVectorQuery("field", new float[] {0, 0}, 8), 10);
+ assertEquals(8, results.scoreDocs.length);
+ assertEquals(0, results.scoreDocs[0].doc);
+ assertEquals(7, results.scoreDocs[7].doc);
+
+ // test some results in the middle of the sequence - also tests docid tiebreaking
+ results = searcher.search(new KnnVectorQuery("field", new float[] {10, 10}, 8), 10);
+ assertEquals(8, results.scoreDocs.length);
+ assertEquals(10, results.scoreDocs[0].doc);
+ assertEquals(6, results.scoreDocs[7].doc);
+ }
+ }
+ }
+
+ /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */
+ public void testRandom() throws IOException {
+ int numDocs = atLeast(100);
+ int dimension = atLeast(5);
+ int numIters = atLeast(10);
+ boolean everyDocHasAVector = random().nextBoolean();
+ try (Directory d = newDirectory()) {
+ RandomIndexWriter w = new RandomIndexWriter(random(), d);
+ for (int i = 0; i < numDocs; i++) {
+ Document doc = new Document();
+ if (everyDocHasAVector || random().nextInt(10) != 2) {
+ doc.add(new KnnVectorField("field", randomVector(dimension)));
+ }
+ w.addDocument(doc);
+ }
+ w.close();
+ try (IndexReader reader = DirectoryReader.open(d)) {
+ IndexSearcher searcher = newSearcher(reader);
+ for (int i = 0; i < numIters; i++) {
+ int k = random().nextInt(100) + 1;
+ KnnVectorQuery query = new KnnVectorQuery("field", randomVector(dimension), k);
+ int n = random().nextInt(100) + 1;
+ TopDocs results = searcher.search(query, n);
+ int expected = Math.min(Math.min(n, k), reader.numDocs());
+ // we may get fewer results than requested if there are deletions, but this test does not
+ // exercise that case
+ assert reader.hasDeletions() == false;
+ assertEquals(expected, results.scoreDocs.length);
+ assertTrue(results.totalHits.value >= results.scoreDocs.length);
+ // verify the results are in descending score order
+ float last = Float.MAX_VALUE;
+ for (ScoreDoc scoreDoc : results.scoreDocs) {
+ assertTrue(scoreDoc.score <= last);
+ last = scoreDoc.score;
+ }
+ }
+ }
+ }
+ }
+
+ private Directory getIndexStore(String field, float[]... contents) throws IOException {
+ Directory indexStore = newDirectory();
+ RandomIndexWriter writer = new RandomIndexWriter(random(), indexStore);
+ for (int i = 0; i < contents.length; ++i) {
+ Document doc = new Document();
+ doc.add(new KnnVectorField(field, contents[i]));
+ doc.add(new StringField("id", "id" + i, Field.Store.NO));
+ writer.addDocument(doc);
+ }
+ writer.close();
+ return indexStore;
+ }
+
+ private void assertMatches(IndexSearcher searcher, Query q, int expectedMatches)
+ throws IOException {
+ ScoreDoc[] result = searcher.search(q, 1000).scoreDocs;
+ assertEquals(expectedMatches, result.length);
+ }
+}
diff --git a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
index d373fa1d1a4..1ebd0562893 100644
--- a/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
+++ b/lucene/core/src/test/org/apache/lucene/util/TestVectorUtil.java
@@ -16,6 +16,8 @@
*/
package org.apache.lucene.util;
+import java.util.Random;
+
public class TestVectorUtil extends LuceneTestCase {
public static final double DELTA = 1e-4;
@@ -81,7 +83,7 @@ public class TestVectorUtil extends LuceneTestCase {
expectThrows(IllegalArgumentException.class, () -> VectorUtil.l2normalize(v));
}
- private float l2(float[] v) {
+ private static float l2(float[] v) {
float l2 = 0;
for (float x : v) {
l2 += x * x;
@@ -89,7 +91,7 @@ public class TestVectorUtil extends LuceneTestCase {
return l2;
}
- private float[] negative(float[] v) {
+ private static float[] negative(float[] v) {
float[] u = new float[v.length];
for (int i = 0; i < v.length; i++) {
u[i] = -v[i];
@@ -97,10 +99,15 @@ public class TestVectorUtil extends LuceneTestCase {
return u;
}
- private float[] randomVector() {
- float[] v = new float[random().nextInt(100) + 1];
- for (int i = 0; i < v.length; i++) {
- v[i] = random().nextFloat();
+ private static float[] randomVector() {
+ return randomVector(random().nextInt(100) + 1);
+ }
+
+ public static float[] randomVector(int dim) {
+ float[] v = new float[dim];
+ Random random = random();
+ for (int i = 0; i < dim; i++) {
+ v[i] = random.nextFloat();
}
return v;
}