Improve DocAndScoreQuery#toString (#12148)

Tiny improvements to DocAndScoreQuery:
* Make toString more informative
* Remove unnecessary 'k' parameter
This commit is contained in:
Julie Tibshirani 2023-02-15 09:50:55 -08:00 committed by GitHub
parent 7baa01b3c2
commit 54044a82a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 19 deletions

View File

@ -197,7 +197,7 @@ abstract class AbstractKnnVectorQuery extends Query {
scores[i] = topK.scoreDocs[i].score; scores[i] = topK.scoreDocs[i].score;
} }
int[] segmentStarts = findSegmentStarts(reader, docs); int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id()); return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
} }
private int[] findSegmentStarts(IndexReader reader, int[] docs) { private int[] findSegmentStarts(IndexReader reader, int[] docs) {
@ -263,7 +263,6 @@ abstract class AbstractKnnVectorQuery extends Query {
/** Caches the results of a KnnVector search: a list of docs and their scores */ /** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query { static class DocAndScoreQuery extends Query {
private final int k;
private final int[] docs; private final int[] docs;
private final float[] scores; private final float[] scores;
private final int[] segmentStarts; private final int[] segmentStarts;
@ -272,7 +271,6 @@ abstract class AbstractKnnVectorQuery extends Query {
/** /**
* Constructor * Constructor
* *
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order * @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents * @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching * @param segmentStarts the indexes in docs and scores corresponding to the first matching
@ -282,9 +280,7 @@ abstract class AbstractKnnVectorQuery extends Query {
* @param contextIdentity an object identifying the reader context that was used to build this * @param contextIdentity an object identifying the reader context that was used to build this
* query * query
*/ */
DocAndScoreQuery( DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
this.docs = docs; this.docs = docs;
this.scores = scores; this.scores = scores;
this.segmentStarts = segmentStarts; this.segmentStarts = segmentStarts;
@ -302,9 +298,9 @@ abstract class AbstractKnnVectorQuery extends Query {
public Explanation explain(LeafReaderContext context, int doc) { public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc + context.docBase); int found = Arrays.binarySearch(docs, doc + context.docBase);
if (found < 0) { if (found < 0) {
return Explanation.noMatch("not in top " + k); return Explanation.noMatch("not in top " + docs.length + " docs");
} }
return Explanation.match(scores[found] * boost, "within top " + k); return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs");
} }
@Override @Override
@ -405,7 +401,7 @@ abstract class AbstractKnnVectorQuery extends Query {
@Override @Override
public String toString(String field) { public String toString(String field) {
return "DocAndScore[" + k + "]"; return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]";
} }
@Override @Override

View File

@ -379,13 +379,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertTrue(matched.isMatch()); assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue()); assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length); assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription()); assertEquals("within top 3 docs", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4); Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch()); assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue()); assertEquals(0f, nomatch.getValue());
assertEquals(0, matched.getDetails().length); assertEquals(0, matched.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription()); assertEquals("not in top 3 docs", nomatch.getDescription());
} }
} }
} }
@ -407,13 +407,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertTrue(matched.isMatch()); assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue()); assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length); assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription()); assertEquals("within top 3 docs", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4); Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch()); assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue()); assertEquals(0f, nomatch.getValue());
assertEquals(0, matched.getDetails().length); assertEquals(0, matched.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription()); assertEquals("not in top 3 docs", nomatch.getDescription());
} }
} }
} }

View File

@ -16,10 +16,14 @@
*/ */
package org.apache.lucene.search; package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.document.Field; import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil; import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@ -65,9 +69,16 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return bytes; return bytes;
} }
public void testToString() { public void testToString() throws IOException {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10); try (Directory indexStore =
assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored")); getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0, 1}, 10);
assertEquals("KnnByteVectorQuery:field[0,...][10]", query.toString("ignored"));
Query rewritten = query.rewrite(newSearcher(reader));
assertEquals("DocAndScoreQuery[0,...][1.0,...]", rewritten.toString("ignored"));
}
} }
public void testGetTarget() { public void testGetTarget() {

View File

@ -60,9 +60,16 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnFloatVectorField(name, vector); return new KnnFloatVectorField(name, vector);
} }
public void testToString() { public void testToString() throws IOException {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10); try (Directory indexStore =
assertEquals("KnnFloatVectorQuery:f1[0.0,...][10]", q1.toString("ignored")); getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10);
assertEquals("KnnFloatVectorQuery:field[0.0,...][10]", query.toString("ignored"));
Query rewritten = query.rewrite(newSearcher(reader));
assertEquals("DocAndScoreQuery[0,...][1.0,...]", rewritten.toString("ignored"));
}
} }
public void testGetTarget() { public void testGetTarget() {