Improve DocAndScoreQuery#toString (#12148)

Tiny improvements to DocAndScoreQuery:
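* Make toString more informative: it now shows the first doc ID and the first score
* Remove the unnecessary 'k' constructor parameter; explanations now report docs.length instead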
This commit is contained in:
Julie Tibshirani 2023-02-15 09:50:55 -08:00 committed by GitHub
parent 7baa01b3c2
commit 54044a82a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 33 additions and 19 deletions

View File

@ -197,7 +197,7 @@ abstract class AbstractKnnVectorQuery extends Query {
scores[i] = topK.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
return new DocAndScoreQuery(k, docs, scores, segmentStarts, reader.getContext().id());
return new DocAndScoreQuery(docs, scores, segmentStarts, reader.getContext().id());
}
private int[] findSegmentStarts(IndexReader reader, int[] docs) {
@ -263,7 +263,6 @@ abstract class AbstractKnnVectorQuery extends Query {
/** Caches the results of a KnnVector search: a list of docs and their scores */
static class DocAndScoreQuery extends Query {
private final int k;
private final int[] docs;
private final float[] scores;
private final int[] segmentStarts;
@ -272,7 +271,6 @@ abstract class AbstractKnnVectorQuery extends Query {
/**
* Constructor
*
* @param k the number of documents requested
* @param docs the global docids of documents that match, in ascending order
* @param scores the scores of the matching documents
* @param segmentStarts the indexes in docs and scores corresponding to the first matching
@ -282,9 +280,7 @@ abstract class AbstractKnnVectorQuery extends Query {
* @param contextIdentity an object identifying the reader context that was used to build this
* query
*/
DocAndScoreQuery(
int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.k = k;
DocAndScoreQuery(int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
@ -302,9 +298,9 @@ abstract class AbstractKnnVectorQuery extends Query {
public Explanation explain(LeafReaderContext context, int doc) {
int found = Arrays.binarySearch(docs, doc + context.docBase);
if (found < 0) {
return Explanation.noMatch("not in top " + k);
return Explanation.noMatch("not in top " + docs.length + " docs");
}
return Explanation.match(scores[found] * boost, "within top " + k);
return Explanation.match(scores[found] * boost, "within top " + docs.length + " docs");
}
@Override
@ -405,7 +401,7 @@ abstract class AbstractKnnVectorQuery extends Query {
@Override
public String toString(String field) {
return "DocAndScore[" + k + "]";
return "DocAndScoreQuery[" + docs[0] + ",...][" + scores[0] + ",...]";
}
@Override

View File

@ -379,13 +379,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription());
assertEquals("within top 3 docs", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription());
assertEquals("not in top 3 docs", nomatch.getDescription());
}
}
}
@ -407,13 +407,13 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase {
assertTrue(matched.isMatch());
assertEquals(1 / 2f, matched.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("within top 3", matched.getDescription());
assertEquals("within top 3 docs", matched.getDescription());
Explanation nomatch = searcher.explain(query, 4);
assertFalse(nomatch.isMatch());
assertEquals(0f, nomatch.getValue());
assertEquals(0, matched.getDetails().length);
assertEquals("not in top 3", nomatch.getDescription());
assertEquals("not in top 3 docs", nomatch.getDescription());
}
}
}

View File

@ -16,10 +16,14 @@
*/
package org.apache.lucene.search;
import java.io.IOException;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.TestVectorUtil;
public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
@ -65,9 +69,16 @@ public class TestKnnByteVectorQuery extends BaseKnnVectorQueryTestCase {
return bytes;
}
public void testToString() {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
assertEquals("KnnByteVectorQuery:f1[0,...][10]", q1.toString("ignored"));
public void testToString() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0, 1}, 10);
assertEquals("KnnByteVectorQuery:field[0,...][10]", query.toString("ignored"));
Query rewritten = query.rewrite(newSearcher(reader));
assertEquals("DocAndScoreQuery[0,...][1.0,...]", rewritten.toString("ignored"));
}
}
public void testGetTarget() {

View File

@ -60,9 +60,16 @@ public class TestKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase {
return new KnnFloatVectorField(name, vector);
}
public void testToString() {
AbstractKnnVectorQuery q1 = getKnnVectorQuery("f1", new float[] {0, 1}, 10);
assertEquals("KnnFloatVectorQuery:f1[0.0,...][10]", q1.toString("ignored"));
public void testToString() throws IOException {
try (Directory indexStore =
getIndexStore("field", new float[] {0, 1}, new float[] {1, 2}, new float[] {0, 0});
IndexReader reader = DirectoryReader.open(indexStore)) {
AbstractKnnVectorQuery query = getKnnVectorQuery("field", new float[] {0.0f, 1.0f}, 10);
assertEquals("KnnFloatVectorQuery:field[0.0,...][10]", query.toString("ignored"));
Query rewritten = query.rewrite(newSearcher(reader));
assertEquals("DocAndScoreQuery[0,...][1.0,...]", rewritten.toString("ignored"));
}
}
public void testGetTarget() {