Enable boosts on JoinUtil queries (#12388)

Boosts should not be ignored by queries returned from JoinUtil
This commit is contained in:
Alan Woodward 2023-06-26 09:47:14 +01:00 committed by GitHub
parent 7f10dca1e5
commit edd799824f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 22 deletions

View File

@ -138,7 +138,8 @@ Optimizations
Bug Fixes
---------------------
(No changes)
* GITHUB#12388: JoinUtil queries were ignoring boosts. (Alan Woodward)
Other
---------------------

View File

@ -27,19 +27,21 @@ abstract class BaseGlobalOrdinalScorer extends Scorer {
final SortedDocValues values;
final DocIdSetIterator approximation;
final float boost;
float score;
public BaseGlobalOrdinalScorer(
Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer) {
Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer, float boost) {
super(weight);
this.values = values;
this.approximation = approximationScorer;
this.boost = boost;
}
@Override
public float score() throws IOException {
return score;
return score * boost;
}
@Override

View File

@ -215,7 +215,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable {
SortedDocValues values,
DocIdSetIterator approximationScorer,
LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximationScorer);
super(weight, values, approximationScorer, 1);
this.score = score;
this.foundOrds = foundOrds;
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
@ -255,7 +255,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable {
LongBitSet foundOrds,
SortedDocValues values,
DocIdSetIterator approximationScorer) {
super(weight, values, approximationScorer);
super(weight, values, approximationScorer, 1);
this.score = score;
this.foundOrds = foundOrds;
}

View File

@ -17,7 +17,6 @@
package org.apache.lucene.search.join;
import java.io.IOException;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues;
@ -117,7 +116,8 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
}
return new W(
this,
toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f));
toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f),
boost);
}
@Override
@ -169,13 +169,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
final class W extends FilterWeight {
W(Query query, Weight approximationWeight) {
final float boost;
W(Query query, Weight approximationWeight, float boost) {
super(query, approximationWeight);
this.boost = boost;
}
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) {
return Explanation.noMatch("Not a match");
}
@ -197,12 +200,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
}
float score = collector.score(ord);
return Explanation.match(score, "A match, join value " + Term.toString(joinValue));
if (boost == 1.0f) {
return Explanation.match(score, "A match, join value " + Term.toString(joinValue));
}
return Explanation.match(
score * boost, "A match, join value " + Term.toString(joinValue) + "^" + boost);
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField);
SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) {
return null;
}
@ -214,11 +221,13 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
return new OrdinalMapScorer(
this,
collector,
boost,
values,
approximationScorer.iterator(),
globalOrds.getGlobalOrds(context.ord));
} else {
return new SegmentOrdinalScorer(this, collector, values, approximationScorer.iterator());
return new SegmentOrdinalScorer(
this, collector, values, boost, approximationScorer.iterator());
}
}
@ -239,10 +248,11 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
public OrdinalMapScorer(
Weight weight,
GlobalOrdinalsWithScoreCollector collector,
float boost,
SortedDocValues values,
DocIdSetIterator approximation,
LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximation);
super(weight, values, approximation, boost);
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
this.collector = collector;
}
@ -280,8 +290,9 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
Weight weight,
GlobalOrdinalsWithScoreCollector collector,
SortedDocValues values,
float boost,
DocIdSetIterator approximation) {
super(weight, values, approximation);
super(weight, values, approximation, boost);
this.collector = collector;
}

View File

@ -151,8 +151,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
postingsEnum = segmentTermsEnum.postings(postingsEnum, PostingsEnum.NONE);
if (postingsEnum.advance(doc) == doc) {
final float score = TermsIncludingScoreQuery.this.scores[ords[i]];
return Explanation.match(
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString());
if (boost == 1.0f) {
return Explanation.match(
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString());
} else {
return Explanation.match(
score * boost,
"Score based on join value "
+ segmentTermsEnum.term().utf8ToString()
+ "^"
+ boost);
}
}
}
}
@ -172,9 +181,11 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
TermsEnum segmentTermsEnum = terms.iterator();
if (multipleValuesPerDocument) {
return new MVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost);
return new MVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
} else {
return new SVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost);
return new SVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
}
}
@ -190,14 +201,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
final DocIdSetIterator matchingDocsIterator;
final float[] scores;
final long cost;
final float boost;
SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException {
SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
throws IOException {
super(weight);
FixedBitSet matchingDocs = new FixedBitSet(maxDoc);
this.scores = new float[maxDoc];
fillDocsAndScores(matchingDocs, termsEnum);
this.matchingDocsIterator = new BitSetIterator(matchingDocs, cost);
this.cost = cost;
this.boost = boost;
}
protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum)
@ -223,7 +237,7 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
@Override
public float score() throws IOException {
return scores[docID()];
return scores[docID()] * boost;
}
@Override
@ -246,8 +260,9 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
// related documents.
class MVInOrderScorer extends SVInOrderScorer {
MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException {
super(weight, termsEnum, maxDoc, cost);
MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
throws IOException {
super(weight, termsEnum, maxDoc, cost, boost);
}
@Override

View File

@ -68,6 +68,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.analysis.MockTokenizer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet;
@ -689,6 +690,7 @@ public class TestJoinUtil extends LuceneTestCase {
}
}
assertEquals(expectedCount, totalHits);
checkBoost(joinQuery, searcher);
}
searcher.getIndexReader().close();
dir.close();
@ -997,6 +999,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode max.
joinQuery =
@ -1011,6 +1014,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode total
joinQuery =
@ -1025,6 +1029,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode avg
joinQuery =
@ -1039,11 +1044,20 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
indexSearcher.getIndexReader().close();
dir.close();
}
private void checkBoost(Query query, IndexSearcher searcher) throws IOException {
TopDocs result = searcher.search(query, 10);
Query boostedQuery = new BoostQuery(query, 10);
TopDocs boostedResult = searcher.search(boostedQuery, 10);
assertEquals(result.scoreDocs[0].score * 10, boostedResult.scoreDocs[0].score, 0.000001f);
QueryUtils.checkExplanations(boostedQuery, searcher);
}
public void testEquals() throws Exception {
final int numDocs = atLeast(random(), 50);
try (final Directory dir = newDirectory()) {