From edd799824f977ec54a6c7fc3662a4da84fa99c8e Mon Sep 17 00:00:00 2001 From: Alan Woodward Date: Mon, 26 Jun 2023 09:47:14 +0100 Subject: [PATCH] Enable boosts on JoinUtil queries (#12388) Boosts should not be ignored by queries returned from JoinUtil --- lucene/CHANGES.txt | 3 +- .../search/join/BaseGlobalOrdinalScorer.java | 6 ++-- .../search/join/GlobalOrdinalsQuery.java | 4 +-- .../join/GlobalOrdinalsWithScoreQuery.java | 29 +++++++++++------ .../search/join/TermsIncludingScoreQuery.java | 31 ++++++++++++++----- .../lucene/search/join/TestJoinUtil.java | 14 +++++++++ 6 files changed, 65 insertions(+), 22 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index f94dab42688..f7c2ab77b63 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -138,7 +138,8 @@ Optimizations Bug Fixes --------------------- -(No changes) + +* GITHUB#12388: JoinUtil queries were ignoring boosts. (Alan Woodward) Other --------------------- diff --git a/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java index 2bed00ccec3..881cbce37ab 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/BaseGlobalOrdinalScorer.java @@ -27,19 +27,21 @@ abstract class BaseGlobalOrdinalScorer extends Scorer { final SortedDocValues values; final DocIdSetIterator approximation; + final float boost; float score; public BaseGlobalOrdinalScorer( - Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer) { + Weight weight, SortedDocValues values, DocIdSetIterator approximationScorer, float boost) { super(weight); this.values = values; this.approximation = approximationScorer; + this.boost = boost; } @Override public float score() throws IOException { - return score; + return score * boost; } @Override diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java index c44b6a2e486..728b701d7cc 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsQuery.java @@ -215,7 +215,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable { SortedDocValues values, DocIdSetIterator approximationScorer, LongValues segmentOrdToGlobalOrdLookup) { - super(weight, values, approximationScorer); + super(weight, values, approximationScorer, 1); this.score = score; this.foundOrds = foundOrds; this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; @@ -255,7 +255,7 @@ final class GlobalOrdinalsQuery extends Query implements Accountable { LongBitSet foundOrds, SortedDocValues values, DocIdSetIterator approximationScorer) { - super(weight, values, approximationScorer); + super(weight, values, approximationScorer, 1); this.score = score; this.foundOrds = foundOrds; } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java index 361ceb0f58e..1d85661003e 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/GlobalOrdinalsWithScoreQuery.java @@ -17,7 +17,6 @@ package org.apache.lucene.search.join; import java.io.IOException; -import org.apache.lucene.index.DocValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.OrdinalMap; import org.apache.lucene.index.SortedDocValues; @@ -117,7 +116,8 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { } return new W( this, - toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f)); + toQuery.createWeight(searcher, org.apache.lucene.search.ScoreMode.COMPLETE_NO_SCORES, 1f), + boost); } @Override @@ -169,13 +169,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { final class W extends FilterWeight { - W(Query query, Weight approximationWeight) { + final float boost; + + W(Query query, Weight approximationWeight, float boost) { super(query, approximationWeight); + this.boost = boost; } @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + SortedDocValues values = context.reader().getSortedDocValues(joinField); if (values == null) { return Explanation.noMatch("Not a match"); } @@ -197,12 +200,16 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { } float score = collector.score(ord); - return Explanation.match(score, "A match, join value " + Term.toString(joinValue)); + if (boost == 1.0f) { + return Explanation.match(score, "A match, join value " + Term.toString(joinValue)); + } + return Explanation.match( + score * boost, "A match, join value " + Term.toString(joinValue) + "^" + boost); } @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SortedDocValues values = DocValues.getSorted(context.reader(), joinField); + SortedDocValues values = context.reader().getSortedDocValues(joinField); if (values == null) { return null; } @@ -214,11 +221,13 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { return new OrdinalMapScorer( this, collector, + boost, values, approximationScorer.iterator(), globalOrds.getGlobalOrds(context.ord)); } else { - return new SegmentOrdinalScorer(this, collector, values, approximationScorer.iterator()); + return new SegmentOrdinalScorer( + this, collector, values, boost, approximationScorer.iterator()); } } @@ -239,10 +248,11 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { public OrdinalMapScorer( Weight weight, GlobalOrdinalsWithScoreCollector collector, + float boost, SortedDocValues values, DocIdSetIterator approximation, LongValues segmentOrdToGlobalOrdLookup) { - super(weight, values, approximation); + super(weight, values, approximation, boost); this.segmentOrdToGlobalOrdLookup = segmentOrdToGlobalOrdLookup; this.collector = collector; } @@ -280,8 +290,9 @@ final class GlobalOrdinalsWithScoreQuery extends Query implements Accountable { Weight weight, GlobalOrdinalsWithScoreCollector collector, SortedDocValues values, + float boost, DocIdSetIterator approximation) { - super(weight, values, approximation); + super(weight, values, approximation, boost); this.collector = collector; } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java index 2d343e42b09..b51fbbbb566 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/TermsIncludingScoreQuery.java @@ -151,8 +151,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable { postingsEnum = segmentTermsEnum.postings(postingsEnum, PostingsEnum.NONE); if (postingsEnum.advance(doc) == doc) { final float score = TermsIncludingScoreQuery.this.scores[ords[i]]; - return Explanation.match( - score, "Score based on join value " + segmentTermsEnum.term().utf8ToString()); + if (boost == 1.0f) { + return Explanation.match( + score, "Score based on join value " + segmentTermsEnum.term().utf8ToString()); + } else { + return Explanation.match( + score * boost, + "Score based on join value " + + segmentTermsEnum.term().utf8ToString() + + "^" + + boost); + } } } } @@ -172,9 +181,11 @@ class TermsIncludingScoreQuery extends Query implements Accountable { TermsEnum segmentTermsEnum = terms.iterator(); if (multipleValuesPerDocument) { - return new MVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost); + return new MVInOrderScorer( + this, segmentTermsEnum, context.reader().maxDoc(), cost, boost); } else { - return new SVInOrderScorer(this, segmentTermsEnum, context.reader().maxDoc(), cost); + return new SVInOrderScorer( + this, segmentTermsEnum, context.reader().maxDoc(), cost, boost); } } @@ -190,14 +201,17 @@ class TermsIncludingScoreQuery extends Query implements Accountable { final DocIdSetIterator matchingDocsIterator; final float[] scores; final long cost; + final float boost; - SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException { + SVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost) + throws IOException { super(weight); FixedBitSet matchingDocs = new FixedBitSet(maxDoc); this.scores = new float[maxDoc]; fillDocsAndScores(matchingDocs, termsEnum); this.matchingDocsIterator = new BitSetIterator(matchingDocs, cost); this.cost = cost; + this.boost = boost; } protected void fillDocsAndScores(FixedBitSet matchingDocs, TermsEnum termsEnum) @@ -223,7 +237,7 @@ class TermsIncludingScoreQuery extends Query implements Accountable { @Override public float score() throws IOException { - return scores[docID()]; + return scores[docID()] * boost; } @Override @@ -246,8 +260,9 @@ class TermsIncludingScoreQuery extends Query implements Accountable { // related documents. class MVInOrderScorer extends SVInOrderScorer { - MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost) throws IOException { - super(weight, termsEnum, maxDoc, cost); + MVInOrderScorer(Weight weight, TermsEnum termsEnum, int maxDoc, long cost, float boost) + throws IOException { + super(weight, termsEnum, maxDoc, cost, boost); } @Override diff --git a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java index 94ea14dd6a3..1ae412010e3 100644 --- a/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java +++ b/lucene/join/src/test/org/apache/lucene/search/join/TestJoinUtil.java @@ -68,6 +68,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; import org.apache.lucene.tests.analysis.MockTokenizer; import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.search.QueryUtils; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.BitSet; @@ -689,6 +690,7 @@ public class TestJoinUtil extends LuceneTestCase { } } assertEquals(expectedCount, totalHits); + checkBoost(joinQuery, searcher); } searcher.getIndexReader().close(); dir.close(); @@ -997,6 +999,7 @@ public class TestJoinUtil extends LuceneTestCase { assertEquals(2, result.totalHits.value); assertEquals(0, result.scoreDocs[0].doc); assertEquals(3, result.scoreDocs[1].doc); + checkBoost(joinQuery, indexSearcher); // Score mode max. joinQuery = @@ -1011,6 +1014,7 @@ public class TestJoinUtil extends LuceneTestCase { assertEquals(2, result.totalHits.value); assertEquals(3, result.scoreDocs[0].doc); assertEquals(0, result.scoreDocs[1].doc); + checkBoost(joinQuery, indexSearcher); // Score mode total joinQuery = @@ -1025,6 +1029,7 @@ public class TestJoinUtil extends LuceneTestCase { assertEquals(2, result.totalHits.value); assertEquals(0, result.scoreDocs[0].doc); assertEquals(3, result.scoreDocs[1].doc); + checkBoost(joinQuery, indexSearcher); // Score mode avg joinQuery = @@ -1039,11 +1044,20 @@ public class TestJoinUtil extends LuceneTestCase { assertEquals(2, result.totalHits.value); assertEquals(3, result.scoreDocs[0].doc); assertEquals(0, result.scoreDocs[1].doc); + checkBoost(joinQuery, indexSearcher); indexSearcher.getIndexReader().close(); dir.close(); } + private void checkBoost(Query query, IndexSearcher searcher) throws IOException { + TopDocs result = searcher.search(query, 10); + Query boostedQuery = new BoostQuery(query, 10); + TopDocs boostedResult = searcher.search(boostedQuery, 10); + assertEquals(result.scoreDocs[0].score * 10, boostedResult.scoreDocs[0].score, 0.000001f); + QueryUtils.checkExplanations(boostedQuery, searcher); + } + public void testEquals() throws Exception { final int numDocs = atLeast(random(), 50); try (final Directory dir = newDirectory()) {