Enable boosts on JoinUtil queries (#12388)

Boosts should not be ignored by queries returned from JoinUtil; this change applies them to both scores and explanations.
This commit is contained in:
Alan Woodward 2023-06-26 09:47:14 +01:00 committed by GitHub
parent 7f10dca1e5
commit edd799824f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 22 deletions

View File

@ -138,7 +138,8 @@ Optimizations
Bug Fixes Bug Fixes
--------------------- ---------------------
(No changes)
* GITHUB#12388: JoinUtil queries were ignoring boosts. (Alan Woodward)
Other Other
--------------------- ---------------------

View File

@ -27,19 +27,21 @@ abstract class BaseGlobalOrdinalScorer extends Scorer {
final SortedDocValues values; final SortedDocValues values;
final DocIdSetIterator approximation; final DocIdSetIterator approximation;
final float boost;
float score; float score;
public BaseGlobalOrdinalScorer( public BaseGlobalOrdinalScorer(
Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer) { Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer, float boost) {
super(weight); super(weight);
this.values = values; this.values = values;
this.approximation = approximationScorer; this.approximation = approximationScorer;
this.boost = boost;
} }
@Override @Override
public float score() throws IOException { public float score() throws IOException {
return score; return score * boost;
} }
@Override @Override

View File

@ -215,7 +215,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable {
SortedDocValues values, SortedDocValues values,
DocIdSetIterator approximationScorer, DocIdSetIterator approximationScorer,
LongValues segmentOrdToGlobalOrdLookup) { LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximationScorer); super(weight, values, approximationScorer, 1);
this.score = score; this.score = score;
this.foundOrds = foundOrds; this.foundOrds = foundOrds;
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
@ -255,7 +255,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable {
LongBitSet foundOrds, LongBitSet foundOrds,
SortedDocValues values, SortedDocValues values,
DocIdSetIterator approximationScorer) { DocIdSetIterator approximationScorer) {
super(weight, values, approximationScorer); super(weight, values, approximationScorer, 1);
this.score = score; this.score = score;
this.foundOrds = foundOrds; this.foundOrds = foundOrds;
} }

View File

@ -17,7 +17,6 @@
package org.apache.lucene.search.join; package org.apache.lucene.search.join;
import java.io.IOException; import java.io.IOException;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.OrdinalMap; import org.apache.lucene.index.OrdinalMap;
import org.apache.lucene.index.SortedDocValues; import org.apache.lucene.index.SortedDocValues;
@ -117,7 +116,8 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
} }
return new W( return new W(
this, this,
toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f)); toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f),
boost);
} }
@Override @Override
@ -169,13 +169,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
final class W extends FilterWeight { final class W extends FilterWeight {
W(Query query, Weight approximationWeight) { final float boost;
W(Query query, Weight approximationWeight, float boost) {
super(query, approximationWeight); super(query, approximationWeight);
this.boost = boost;
} }
@Override @Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException { public Explanation explain(LeafReaderContext context, int doc) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField); SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) { if (values == null) {
return Explanation.noMatch("Not a match"); return Explanation.noMatch("Not a match");
} }
@ -197,12 +200,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
} }
float score = collector.score(ord); float score = collector.score(ord);
return Explanation.match(score, "A match, join value " + Term.toString(joinValue)); if (boost == 1.0f) {
return Explanation.match(score, "A match, join value " + Term.toString(joinValue));
}
return Explanation.match(
score * boost, "A match, join value " + Term.toString(joinValue) + "^" + boost);
} }
@Override @Override
public Scorer scorer(LeafReaderContext context) throws IOException { public Scorer scorer(LeafReaderContext context) throws IOException {
SortedDocValues values = DocValues.getSorted(context.reader(), joinField); SortedDocValues values = context.reader().getSortedDocValues(joinField);
if (values == null) { if (values == null) {
return null; return null;
} }
@ -214,11 +221,13 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
return new OrdinalMapScorer( return new OrdinalMapScorer(
this, this,
collector, collector,
boost,
values, values,
approximationScorer.iterator(), approximationScorer.iterator(),
globalOrds.getGlobalOrds(context.ord)); globalOrds.getGlobalOrds(context.ord));
} else { } else {
return new SegmentOrdinalScorer(this, collector, values, approximationScorer.iterator()); return new SegmentOrdinalScorer(
this, collector, values, boost, approximationScorer.iterator());
} }
} }
@ -239,10 +248,11 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
public OrdinalMapScorer( public OrdinalMapScorer(
Weight weight, Weight weight,
GlobalOrdinalsWithScoreCollector collector, GlobalOrdinalsWithScoreCollector collector,
float boost,
SortedDocValues values, SortedDocValues values,
DocIdSetIterator approximation, DocIdSetIterator approximation,
LongValues segmentOrdToGlobalOrdLookup) { LongValues segmentOrdToGlobalOrdLookup) {
super(weight, values, approximation); super(weight, values, approximation, boost);
this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup;
this.collector = collector; this.collector = collector;
} }
@ -280,8 +290,9 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable {
Weight weight, Weight weight,
GlobalOrdinalsWithScoreCollector collector, GlobalOrdinalsWithScoreCollector collector,
SortedDocValues values, SortedDocValues values,
float boost,
DocIdSetIterator approximation) { DocIdSetIterator approximation) {
super(weight, values, approximation); super(weight, values, approximation, boost);
this.collector = collector; this.collector = collector;
} }

View File

@ -151,8 +151,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
postingsEnum = segmentTermsEnum.postings(postingsEnum, PostingsEnum.NONE); postingsEnum = segmentTermsEnum.postings(postingsEnum, PostingsEnum.NONE);
if (postingsEnum.advance(doc) == doc) { if (postingsEnum.advance(doc) == doc) {
final float score = TermsIncludingScoreQuery.this.scores[ords[i]]; final float score = TermsIncludingScoreQuery.this.scores[ords[i]];
return Explanation.match( if (boost == 1.0f) {
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString()); return Explanation.match(
score, "Score based on join value " + segmentTermsEnum.term().utf8ToString());
} else {
return Explanation.match(
score * boost,
"Score based on join value "
+ segmentTermsEnum.term().utf8ToString()
+ "^"
+ boost);
}
} }
} }
} }
@ -172,9 +181,11 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
TermsEnum segmentTermsEnum = terms.iterator(); TermsEnum segmentTermsEnum = terms.iterator();
if (multipleValuesPerDocument) { if (multipleValuesPerDocument) {
return new MVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost); return new MVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
} else { } else {
return new SVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost); return new SVInOrderScorer(
this, segmentTermsEnum, context.reader().maxDoc(), cost, boost);
} }
} }
@ -190,14 +201,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
final DocIdSetIterator matchingDocsIterator; final DocIdSetIterator matchingDocsIterator;
final float[] scores; final float[] scores;
final long cost; final long cost;
final float boost;
SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException { SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
throws IOException {
super(weight); super(weight);
FixedBitSet matchingDocs = new FixedBitSet(maxDoc); FixedBitSet matchingDocs = new FixedBitSet(maxDoc);
this.scores = new float[maxDoc]; this.scores = new float[maxDoc];
fillDocsAndScores(matchingDocs, termsEnum); fillDocsAndScores(matchingDocs, termsEnum);
this.matchingDocsIterator = new BitSetIterator(matchingDocs, cost); this.matchingDocsIterator = new BitSetIterator(matchingDocs, cost);
this.cost = cost; this.cost = cost;
this.boost = boost;
} }
protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum) protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum)
@ -223,7 +237,7 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
@Override @Override
public float score() throws IOException { public float score() throws IOException {
return scores[docID()]; return scores[docID()] * boost;
} }
@Override @Override
@ -246,8 +260,9 @@ class TermsIncludingScoreQuery extends Query implements Accountable {
// related documents. // related documents.
class MVInOrderScorer extends SVInOrderScorer { class MVInOrderScorer extends SVInOrderScorer {
MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException { MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost)
super(weight, termsEnum, maxDoc, cost); throws IOException {
super(weight, termsEnum, maxDoc, cost, boost);
} }
@Override @Override

View File

@ -68,6 +68,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.analysis.MockTokenizer; import org.apache.lucene.tests.analysis.MockTokenizer;
import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.QueryUtils;
import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase;
import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TestUtil;
import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSet;
@ -689,6 +690,7 @@ public class TestJoinUtil extends LuceneTestCase {
} }
} }
assertEquals(expectedCount, totalHits); assertEquals(expectedCount, totalHits);
checkBoost(joinQuery, searcher);
} }
searcher.getIndexReader().close(); searcher.getIndexReader().close();
dir.close(); dir.close();
@ -997,6 +999,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value); assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc); assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc); assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode max. // Score mode max.
joinQuery = joinQuery =
@ -1011,6 +1014,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value); assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc); assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc); assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode total // Score mode total
joinQuery = joinQuery =
@ -1025,6 +1029,7 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value); assertEquals(2, result.totalHits.value);
assertEquals(0, result.scoreDocs[0].doc); assertEquals(0, result.scoreDocs[0].doc);
assertEquals(3, result.scoreDocs[1].doc); assertEquals(3, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
// Score mode avg // Score mode avg
joinQuery = joinQuery =
@ -1039,11 +1044,20 @@ public class TestJoinUtil extends LuceneTestCase {
assertEquals(2, result.totalHits.value); assertEquals(2, result.totalHits.value);
assertEquals(3, result.scoreDocs[0].doc); assertEquals(3, result.scoreDocs[0].doc);
assertEquals(0, result.scoreDocs[1].doc); assertEquals(0, result.scoreDocs[1].doc);
checkBoost(joinQuery, indexSearcher);
indexSearcher.getIndexReader().close(); indexSearcher.getIndexReader().close();
dir.close(); dir.close();
} }
/**
 * Verifies that wrapping {@code query} in a {@link BoostQuery} scales its scores by the boost.
 *
 * <p>Runs the query unboosted and with a boost of 10, compares the top hit's score of both runs,
 * and validates the explanations of the boosted query.
 *
 * @param query the join query under test
 * @param searcher the searcher to run the query against
 * @throws IOException if searching fails
 */
private void checkBoost(Query query, IndexSearcher searcher) throws IOException {
  final float boost = 10f;
  TopDocs result = searcher.search(query, 10);
  if (result.scoreDocs.length == 0) {
    // Nothing matched, so there is no top score to compare; indexing scoreDocs[0] would throw.
    return;
  }
  Query boostedQuery = new BoostQuery(query, boost);
  TopDocs boostedResult = searcher.search(boostedQuery, 10);
  assertEquals(
      result.scoreDocs[0].score * boost, boostedResult.scoreDocs[0].score, 0.000001f);
  QueryUtils.checkExplanations(boostedQuery, searcher);
}
public void testEquals() throws Exception { public void testEquals() throws Exception {
final int numDocs = atLeast(random(), 50); final int numDocs = atLeast(random(), 50);
try (final Directory dir = newDirectory()) { try (final Directory dir = newDirectory()) {