Make inlining decisions a bit more predictable in our main queries. (#14023)

This implements a small contained hack to make sure that our compound scorers
like `MaxScoreBulkScorer`, `ConjunctionBulkScorer`,
`BlockMaxConjunctionBulkScorer`, `WANDScorer` and `ConjunctionDISI` only have
two concrete implementations of `DocIdSetIterator` and `Scorable` to deal with.

This helps because it makes calls to `DocIdSetIterator#nextDoc()`,
`DocIdSetIterator#advance(int)` and `Scorable#score()` bimorphic at most, and
bimorphic calls are candidate for inlining.

This should help speed up boolean queries of term queries at the expense of
boolean queries of other query types. This feels fair to me as it gives more
speedups than slowdowns in benchmarks, and that boolean queries of term queries
are extremely typical. Boolean queries that mix term queries and other types of
queries may get a slowdown or a speedup depending on whether they get more from
the speedup on their term clauses than they lose on their other clauses.
This commit is contained in:
Adrien Grand 2024-11-29 13:27:49 +01:00 committed by GitHub
parent 70530a92d9
commit f9869b54d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 175 additions and 39 deletions

View File

@ -116,6 +116,9 @@ Optimizations
* GITHUB#14014: Filtered disjunctions now get executed via `MaxScoreBulkScorer`.
(Adrien Grand)
* GITHUB#14023: Make JVM inlining decisions more predictable in our main
queries. (Adrien Grand)
Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended

View File

@ -38,7 +38,7 @@ final class BlockMaxConjunctionBulkScorer extends BulkScorer {
private final Scorer[] scorers;
private final DocIdSetIterator[] iterators;
private final DocIdSetIterator lead1, lead2;
private final Scorer scorer1, scorer2;
private final Scorable scorer1, scorer2;
private final DocAndScore scorable = new DocAndScore();
private final double[] sumOfOtherClauses;
private final int maxDoc;
@ -51,10 +51,10 @@ final class BlockMaxConjunctionBulkScorer extends BulkScorer {
Arrays.sort(this.scorers, Comparator.comparingLong(scorer -> scorer.iterator().cost()));
this.iterators =
Arrays.stream(this.scorers).map(Scorer::iterator).toArray(DocIdSetIterator[]::new);
lead1 = iterators[0];
lead2 = iterators[1];
scorer1 = this.scorers[0];
scorer2 = this.scorers[1];
lead1 = ScorerUtil.likelyImpactsEnum(iterators[0]);
lead2 = ScorerUtil.likelyImpactsEnum(iterators[1]);
scorer1 = ScorerUtil.likelyTermScorer(this.scorers[0]);
scorer2 = ScorerUtil.likelyTermScorer(this.scorers[1]);
this.sumOfOtherClauses = new double[this.scorers.length];
for (int i = 0; i < sumOfOtherClauses.length; i++) {
sumOfOtherClauses[i] = Double.POSITIVE_INFINITY;

View File

@ -29,6 +29,7 @@ import java.util.List;
*/
final class BlockMaxConjunctionScorer extends Scorer {
final Scorer[] scorers;
final Scorable[] scorables;
final DocIdSetIterator[] approximations;
final TwoPhaseIterator[] twoPhases;
float minScore;
@ -38,6 +39,8 @@ final class BlockMaxConjunctionScorer extends Scorer {
this.scorers = scorersList.toArray(new Scorer[scorersList.size()]);
// Sort scorer by cost
Arrays.sort(this.scorers, Comparator.comparingLong(s -> s.iterator().cost()));
this.scorables =
Arrays.stream(scorers).map(ScorerUtil::likelyTermScorer).toArray(Scorable[]::new);
this.approximations = new DocIdSetIterator[scorers.length];
List<TwoPhaseIterator> twoPhaseList = new ArrayList<>();
@ -50,6 +53,7 @@ final class BlockMaxConjunctionScorer extends Scorer {
} else {
approximations[i] = scorer.iterator();
}
approximations[i] = ScorerUtil.likelyImpactsEnum(approximations[i]);
scorer.advanceShallow(0);
}
this.twoPhases = twoPhaseList.toArray(new TwoPhaseIterator[twoPhaseList.size()]);
@ -207,7 +211,7 @@ final class BlockMaxConjunctionScorer extends Scorer {
@Override
public float score() throws IOException {
double score = 0;
for (Scorer scorer : scorers) {
for (Scorable scorer : scorables) {
score += scorer.score();
}
return (float) score;

View File

@ -155,7 +155,7 @@ final class BooleanScorer extends BulkScorer {
this.needsScores = needsScores;
LongArrayList costs = new LongArrayList(scorers.size());
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
DisiWrapper w = new DisiWrapper(scorer, false);
costs.add(w.cost);
final DisiWrapper evicted = tail.insertWithOverflow(w);
if (evicted != null) {
@ -177,7 +177,7 @@ final class BooleanScorer extends BulkScorer {
Bucket[] buckets = BooleanScorer.this.buckets;
DocIdSetIterator it = w.iterator;
Scorer scorer = w.scorer;
Scorable scorer = w.scorable;
int doc = w.doc;
if (doc < min) {
doc = it.advance(min);

View File

@ -30,7 +30,7 @@ import org.apache.lucene.util.Bits;
*/
final class ConjunctionBulkScorer extends BulkScorer {
private final Scorer[] scoringScorers;
private final Scorable[] scoringScorers;
private final DocIdSetIterator lead1, lead2;
private final List<DocIdSetIterator> others;
private final Scorable scorable;
@ -45,7 +45,8 @@ final class ConjunctionBulkScorer extends BulkScorer {
allScorers.addAll(requiredScoring);
allScorers.addAll(requiredNoScoring);
this.scoringScorers = requiredScoring.toArray(Scorer[]::new);
this.scoringScorers =
requiredScoring.stream().map(ScorerUtil::likelyTermScorer).toArray(Scorable[]::new);
List<DocIdSetIterator> iterators = new ArrayList<>();
for (Scorer scorer : allScorers) {
iterators.add(scorer.iterator());
@ -59,7 +60,7 @@ final class ConjunctionBulkScorer extends BulkScorer {
@Override
public float score() throws IOException {
double score = 0;
for (Scorer scorer : scoringScorers) {
for (Scorable scorer : scoringScorers) {
score += scorer.score();
}
return (float) score;

View File

@ -16,6 +16,8 @@
*/
package org.apache.lucene.search;
import java.util.Objects;
/**
* Wrapper used in {@link DisiPriorityQueue}.
*
@ -24,6 +26,7 @@ package org.apache.lucene.search;
public class DisiWrapper {
public final DocIdSetIterator iterator;
public final Scorer scorer;
public final Scorable scorable;
public final long cost;
public final float matchCost; // the match cost for two-phase iterators, 0 otherwise
public int doc; // the current doc, used for comparison
@ -42,9 +45,14 @@ public class DisiWrapper {
// for MaxScoreBulkScorer
float maxWindowScore;
public DisiWrapper(Scorer scorer) {
this.scorer = scorer;
this.iterator = scorer.iterator();
public DisiWrapper(Scorer scorer, boolean impacts) {
this.scorer = Objects.requireNonNull(scorer);
this.scorable = ScorerUtil.likelyTermScorer(scorer);
if (impacts) {
this.iterator = ScorerUtil.likelyImpactsEnum(scorer.iterator());
} else {
this.iterator = scorer.iterator();
}
this.cost = iterator.cost();
this.doc = -1;
this.twoPhaseView = scorer.twoPhaseIterator();

View File

@ -60,7 +60,7 @@ final class DisjunctionMaxScorer extends DisjunctionScorer {
float scoreMax = 0;
double otherScoreSum = 0;
for (DisiWrapper w = topList; w != null; w = w.next) {
float subScore = w.scorer.score();
float subScore = w.scorable.score();
if (subScore >= scoreMax) {
otherScoreSum += scoreMax;
scoreMax = subScore;

View File

@ -37,7 +37,7 @@ abstract class DisjunctionScorer extends Scorer {
}
this.subScorers = new DisiPriorityQueue(subScorers.size());
for (Scorer scorer : subScorers) {
final DisiWrapper w = new DisiWrapper(scorer);
final DisiWrapper w = new DisiWrapper(scorer, false);
this.subScorers.add(w);
}
this.needsScores = scoreMode != ScoreMode.COMPLETE_NO_SCORES;

View File

@ -40,7 +40,7 @@ final class DisjunctionSumScorer extends DisjunctionScorer {
double score = 0;
for (DisiWrapper w = topList; w != null; w = w.next) {
score += w.scorer.score();
score += w.scorable.score();
}
return (float) score;
}

View File

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
/** Wrapper around a {@link DocIdSetIterator}. */
public class FilterDocIdSetIterator extends DocIdSetIterator {
/** Wrapped instance. */
protected final DocIdSetIterator in;
/** Sole constructor. */
public FilterDocIdSetIterator(DocIdSetIterator in) {
this.in = in;
}
@Override
public int docID() {
return in.docID();
}
@Override
public int nextDoc() throws IOException {
return in.nextDoc();
}
@Override
public int advance(int target) throws IOException {
return in.advance(target);
}
@Override
public long cost() {
return in.cost();
}
}

View File

@ -35,7 +35,7 @@ public abstract class IndriDisjunctionScorer extends IndriScorer {
this.subScorersList = subScorersList;
this.subScorers = new DisiPriorityQueue(subScorersList.size());
for (Scorer scorer : subScorersList) {
final DisiWrapper w = new DisiWrapper(scorer);
final DisiWrapper w = new DisiWrapper(scorer, false);
this.subScorers.add(w);
}
this.approximation = new DisjunctionDISIApproximation(this.subScorers);

View File

@ -53,13 +53,13 @@ final class MaxScoreBulkScorer extends BulkScorer {
MaxScoreBulkScorer(int maxDoc, List<Scorer> scorers, Scorer filter) throws IOException {
this.maxDoc = maxDoc;
this.filter = filter == null ? null : new DisiWrapper(filter);
this.filter = filter == null ? null : new DisiWrapper(filter, false);
allScorers = new DisiWrapper[scorers.size()];
scratch = new DisiWrapper[allScorers.length];
int i = 0;
long cost = 0;
for (Scorer scorer : scorers) {
DisiWrapper w = new DisiWrapper(scorer);
DisiWrapper w = new DisiWrapper(scorer, true);
cost += w.cost;
allScorers[i++] = w;
}
@ -221,7 +221,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
if (acceptDocs != null && acceptDocs.get(doc) == false) {
continue;
}
scoreNonEssentialClauses(collector, doc, top.scorer.score(), firstEssentialScorer);
scoreNonEssentialClauses(collector, doc, top.scorable.score(), firstEssentialScorer);
}
top.doc = top.iterator.docID();
essentialQueue.updateTop();
@ -249,7 +249,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
continue;
}
double score = lead1.scorer.score();
double score = lead1.scorable.score();
// We specialize handling the second best scorer, which seems to help a bit with performance.
// But this is the exact same logic as in the below for loop.
@ -268,7 +268,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
continue;
}
score += lead2.scorer.score();
score += lead2.scorable.score();
for (int i = allScorers.length - 3; i >= firstRequiredScorer; --i) {
if ((float) MathUtil.sumUpperBound(score + maxScoreSums[i], allScorers.length)
@ -286,7 +286,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
lead1.doc = lead1.iterator.advance(Math.min(w.doc, max));
continue outer;
}
score += w.scorer.score();
score += w.scorable.score();
}
scoreNonEssentialClauses(collector, lead1.doc, score, firstRequiredScorer);
@ -307,7 +307,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
if (acceptDocs == null || acceptDocs.get(doc)) {
final int i = doc - innerWindowMin;
windowMatches[i >>> 6] |= 1L << i;
windowScores[i] += top.scorer.score();
windowScores[i] += top.scorable.score();
}
}
top.doc = top.iterator.docID();
@ -399,7 +399,7 @@ final class MaxScoreBulkScorer extends BulkScorer {
scorer.doc = scorer.iterator.advance(doc);
}
if (scorer.doc == doc) {
score += scorer.scorer.score();
score += scorer.scorable.score();
}
}

View File

@ -113,10 +113,10 @@ final class MultiTermQueryConstantScoreBlendedWrapper<Q extends MultiTermQuery>
DisiPriorityQueue subs = new DisiPriorityQueue(highFrequencyTerms.size() + 1);
for (DocIdSetIterator disi : highFrequencyTerms) {
Scorer s = wrapWithDummyScorer(this, disi);
subs.add(new DisiWrapper(s));
subs.add(new DisiWrapper(s, false));
}
Scorer s = wrapWithDummyScorer(this, otherTerms.build().iterator());
subs.add(new DisiWrapper(s));
subs.add(new DisiWrapper(s, false));
return new WeightOrDocIdSetIterator(new DisjunctionDISIApproximation(subs));
}

View File

@ -16,12 +16,48 @@
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.stream.LongStream;
import java.util.stream.StreamSupport;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.PriorityQueue;
/** Util class for Scorer related methods */
class ScorerUtil {
private static final Class<?> DEFAULT_IMPACTS_ENUM_CLASS;
static {
try (Directory dir = new ByteBuffersDirectory();
IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) {
Document doc = new Document();
doc.add(new FeatureField("field", "value", 1f));
w.addDocument(doc);
try (DirectoryReader reader = DirectoryReader.open(w)) {
LeafReader leafReader = reader.leaves().get(0).reader();
TermsEnum te = leafReader.terms("field").iterator();
if (te.seekExact(new BytesRef("value")) == false) {
throw new Error();
}
ImpactsEnum ie = te.impacts(PostingsEnum.FREQS);
DEFAULT_IMPACTS_ENUM_CLASS = ie.getClass();
}
} catch (IOException e) {
throw new Error(e);
}
}
static long costWithMinShouldMatch(LongStream costs, int numScorers, int minShouldMatch) {
// the idea here is the following: a boolean query c1,c2,...cn with minShouldMatch=m
// could be rewritten to:
@ -46,4 +82,30 @@ class ScorerUtil {
costs.forEach(pq::insertWithOverflow);
return StreamSupport.stream(pq.spliterator(), false).mapToLong(Number::longValue).sum();
}
/**
* Optimize a {@link DocIdSetIterator} for the case when it is likely implemented via an {@link
* ImpactsEnum}. This return method only has 2 possible return types, which helps make sure that
* calls to {@link DocIdSetIterator#nextDoc()} and {@link DocIdSetIterator#advance(int)} are
* bimorphic at most and candidate for inlining.
*/
static DocIdSetIterator likelyImpactsEnum(DocIdSetIterator it) {
if (it.getClass() != DEFAULT_IMPACTS_ENUM_CLASS
&& it.getClass() != FilterDocIdSetIterator.class) {
it = new FilterDocIdSetIterator(it);
}
return it;
}
/**
* Optimize a {@link Scorable} for the case when it is likely implemented via a {@link
* TermScorer}. This return method only has 2 possible return types, which helps make sure that
* calls to {@link Scorable#score()} are bimorphic at most and candidate for inlining.
*/
static Scorable likelyTermScorer(Scorable scorable) {
if (scorable.getClass() != TermScorer.class && scorable.getClass() != FilterScorable.class) {
scorable = new FilterScorable(scorable);
}
return scorable;
}
}

View File

@ -646,7 +646,7 @@ public final class SynonymQuery extends Query {
final float boost;
DisiWrapperFreq(Scorer scorer, float boost) {
super(scorer);
super(scorer, false);
this.pe = (PostingsEnum) scorer.iterator();
this.boost = boost;
}

View File

@ -196,7 +196,12 @@ final class WANDScorer extends Scorer {
}
for (Scorer scorer : scorers) {
addUnpositionedLead(new DisiWrapper(scorer));
// Ideally we would pass true when scoreMode == TOP_SCORES and false otherwise, but this would
// break the optimization as there could then be 3 different impls of DocIdSetIterator
// (ImpactsEnum, PostingsEnum and <Else>). So we pass true to favor disjunctions sorted by
// descending score as opposed to non-scoring disjunctions whose minShouldMatch is greater
// than 1.
addUnpositionedLead(new DisiWrapper(scorer, true));
}
this.cost =
@ -221,7 +226,7 @@ final class WANDScorer extends Scorer {
List<Float> leadScores = new ArrayList<>();
for (DisiWrapper w = lead; w != null; w = w.next) {
assert w.doc == doc;
leadScores.add(w.scorer.score());
leadScores.add(w.scorable.score());
}
// Make sure to recompute the sum in the same order to get the same floating point rounding
// errors.
@ -370,7 +375,7 @@ final class WANDScorer extends Scorer {
this.lead = lead;
freq += 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore += lead.scorer.score();
leadScore += lead.scorable.score();
}
}
@ -522,7 +527,7 @@ final class WANDScorer extends Scorer {
lead.next = null;
freq = 1;
if (scoreMode == ScoreMode.TOP_SCORES) {
leadScore = lead.scorer.score();
leadScore = lead.scorable.score();
}
while (head.size() > 0 && head.top().doc == doc) {
addLead(head.pop());
@ -553,7 +558,7 @@ final class WANDScorer extends Scorer {
if (scoreMode != ScoreMode.TOP_SCORES) {
// With TOP_SCORES, the score was already computed on the fly.
for (DisiWrapper s = lead; s != null; s = s.next) {
leadScore += s.scorer.score();
leadScore += s.scorable.score();
}
}
return (float) leadScore;

View File

@ -70,7 +70,7 @@ public class TestDisiPriorityQueue extends LuceneTestCase {
private static DisiWrapper wrapper(DocIdSetIterator iterator) throws IOException {
Query q = new DummyQuery(iterator);
Scorer s = q.createWeight(null, ScoreMode.COMPLETE_NO_SCORES, 1.0f).scorer(null);
return new DisiWrapper(s);
return new DisiWrapper(s, random().nextBoolean());
}
private static DocIdSetIterator randomDisi(Random r) {

View File

@ -422,15 +422,17 @@ public final class CombinedFieldQuery extends Query implements Accountable {
}
private static class WeightedDisiWrapper extends DisiWrapper {
final PostingsEnum postingsEnum;
final float weight;
WeightedDisiWrapper(Scorer scorer, float weight) {
super(scorer);
super(scorer, false);
this.weight = weight;
this.postingsEnum = (PostingsEnum) scorer.iterator();
}
float freq() throws IOException {
return weight * ((PostingsEnum) iterator).freq();
return weight * postingsEnum.freq();
}
}

View File

@ -54,7 +54,7 @@ final class CoveringScorer extends Scorer {
subScorers = new DisiPriorityQueue(scorers.size());
for (Scorer scorer : scorers) {
subScorers.add(new DisiWrapper(scorer));
subScorers.add(new DisiWrapper(scorer, false));
}
this.cost = scorers.stream().map(Scorer::iterator).mapToLong(DocIdSetIterator::cost).sum();
@ -210,7 +210,7 @@ final class CoveringScorer extends Scorer {
setTopListAndFreqIfNecessary();
double score = 0;
for (DisiWrapper w = topList; w != null; w = w.next) {
score += w.scorer.score();
score += w.scorable.score();
}
return (float) score;
}