LUCENE-7998: QueryDoubleValuesSource

This commit is contained in:
Alan Woodward 2017-10-19 09:06:05 +01:00
parent 75825d240f
commit 423677f20e
2 changed files with 157 additions and 0 deletions

View File

@ -512,4 +512,127 @@ public abstract class DoubleValuesSource implements SegmentCacheable {
};
}
/**
* Create a DoubleValuesSource that returns the score of a particular query
*/
public static DoubleValuesSource fromQuery(Query query) {
return new QueryDoubleValuesSource(query);
}
private static class QueryDoubleValuesSource extends DoubleValuesSource {
private final Query query;
private QueryDoubleValuesSource(Query query) {
this.query = query;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
QueryDoubleValuesSource that = (QueryDoubleValuesSource) o;
return Objects.equals(query, that.query);
}
@Override
public int hashCode() {
return Objects.hash(query);
}
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
throw new UnsupportedOperationException("This DoubleValuesSource must be rewritten");
}
@Override
public boolean needsScores() {
return false;
}
@Override
public DoubleValuesSource rewrite(IndexSearcher searcher) throws IOException {
return new WeightDoubleValuesSource(searcher.rewrite(query).createWeight(searcher, true, 1f));
}
@Override
public String toString() {
return "score(" + query.toString() + ")";
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
}
private static class WeightDoubleValuesSource extends DoubleValuesSource {
private final Weight weight;
private WeightDoubleValuesSource(Weight weight) {
this.weight = weight;
}
@Override
public DoubleValues getValues(LeafReaderContext ctx, DoubleValues scores) throws IOException {
Scorer scorer = weight.scorer(ctx);
if (scorer == null)
return DoubleValues.EMPTY;
DocIdSetIterator it = scorer.iterator();
return new DoubleValues() {
@Override
public double doubleValue() throws IOException {
return scorer.score();
}
@Override
public boolean advanceExact(int doc) throws IOException {
if (it.docID() > doc)
return false;
return it.docID() == doc || it.advance(doc) == doc;
}
};
}
@Override
public Explanation explain(LeafReaderContext ctx, int docId, Explanation scoreExplanation) throws IOException {
return weight.explain(ctx, docId);
}
@Override
public boolean needsScores() {
return false;
}
@Override
public DoubleValuesSource rewrite(IndexSearcher searcher) throws IOException {
return this;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
WeightDoubleValuesSource that = (WeightDoubleValuesSource) o;
return Objects.equals(weight, that.weight);
}
@Override
public int hashCode() {
return Objects.hash(weight);
}
@Override
public String toString() {
return "score(" + weight.parentQuery.toString() + ")";
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return false;
}
}
}

View File

@ -186,6 +186,7 @@ public class TestDoubleValuesSource extends LuceneTestCase {
public void testExplanations() throws Exception {
for (Query q : testQueries) {
testExplanations(q, DoubleValuesSource.fromQuery(new TermQuery(new Term("english", "one"))));
testExplanations(q, DoubleValuesSource.fromIntField("int"));
testExplanations(q, DoubleValuesSource.fromLongField("long"));
testExplanations(q, DoubleValuesSource.fromFloatField("float"));
@ -230,4 +231,37 @@ public class TestDoubleValuesSource extends LuceneTestCase {
});
}
public void testQueryDoubleValuesSource() throws Exception {
Query q = new TermQuery(new Term("english", "two"));
DoubleValuesSource vs = DoubleValuesSource.fromQuery(q).rewrite(searcher);
searcher.search(q, new SimpleCollector() {
DoubleValues v;
Scorer scorer;
LeafReaderContext ctx;
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
this.ctx = context;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
this.scorer = scorer;
this.v = vs.getValues(this.ctx, DoubleValuesSource.fromScorer(scorer));
}
@Override
public void collect(int doc) throws IOException {
assertTrue(v.advanceExact(doc));
assertEquals(scorer.score(), v.doubleValue(), 0.00001);
}
@Override
public boolean needsScores() {
return true;
}
});
}
}