Add a specialized bulk scorer for regular conjunctions. (#12719)

PR #12382 added a bulk scorer for top-k hits on conjunctions, which yielded a
significant speedup (annotation
[FP](http://people.apache.org/~mikemccand/lucenebench/AndHighHigh.html)). This
change introduces a similar specialization for exhaustive collection of
conjunctive queries, e.g. for counting, faceting, etc.
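
For context, the earlier bulk scorer only helps queries that collect the top-k hits
by score; this change targets consumers that visit every matching document, such as
counting and faceting. A minimal sketch of that kind of workload (field and term
names are invented, and `searcher` is assumed to be an open IndexSearcher):

```java
import java.io.IOException;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TermQuery;

// Field and term names are hypothetical; any pure conjunction works the same way.
static int countRedAndLarge(IndexSearcher searcher) throws IOException {
  BooleanQuery query =
      new BooleanQuery.Builder()
          .add(new TermQuery(new Term("color", "red")), BooleanClause.Occur.MUST)
          .add(new TermQuery(new Term("size", "large")), BooleanClause.Occur.MUST)
          .build();
  // count() visits every match and never sorts by score, which is the
  // exhaustive-collection case this change targets.
  return searcher.count(query);
}
```
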
Author: Adrien Grand, 2023-10-30 16:11:19 +01:00 (committed by GitHub)
Parent: 2ed60e8073
Commit: 58b9352cc2
5 changed files with 273 additions and 1 deletion

File: lucene/CHANGES.txt

@@ -244,6 +244,9 @@ Optimizations
* GITHUB#12726: Return the same input vector if it's a unit vector in VectorUtil#l2normalize. (Shubham Chaudhary)
* GITHUB#12719: Top-level conjunctions that are not sorted by score now have a
specialized bulk scorer. (Adrien Grand)
Changes in runtime behavior
---------------------

File: lucene/core/src/java/org/apache/lucene/search/BooleanWeight.java

@@ -317,6 +317,12 @@ final class BooleanWeight extends Weight {
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
return new BlockMaxConjunctionBulkScorer(context.reader().maxDoc(), requiredScoring);
}
if (scoreMode != ScoreMode.TOP_SCORES
&& requiredScoring.size() + requiredNoScoring.size() >= 2
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
}
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
requiredScoring =
Collections.singletonList(new BlockMaxConjunctionScorer(this, requiredScoring));
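
In other words, the new branch applies when scores are not needed for ranking (any
score mode other than TOP_SCORES), there are at least two required clauses in total,
and every required clause exposes a plain DocIdSetIterator with no TwoPhaseIterator.
A hedged sketch of a query shape that would qualify (field names are illustrative,
imports as in the earlier snippet):

```java
// Two required scoring clauses plus a required filter clause, all term queries,
// so every iterator is a plain DocIdSetIterator: exhaustive collection of this
// query can go through ConjunctionBulkScorer.
Query eligible =
    new BooleanQuery.Builder()
        .add(new TermQuery(new Term("brand", "acme")), BooleanClause.Occur.MUST)
        .add(new TermQuery(new Term("status", "active")), BooleanClause.Occur.MUST)
        .add(new TermQuery(new Term("tenant", "42")), BooleanClause.Occur.FILTER)
        .build();
// Replacing one clause with a PhraseQuery would introduce a TwoPhaseIterator and
// the conjunction would fall back to the generic scoring path instead.
```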

File: lucene/core/src/java/org/apache/lucene/search/ConjunctionBulkScorer.java (new file)

@@ -0,0 +1,177 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.search;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.util.Bits;
/**
* BulkScorer implementation of {@link ConjunctionScorer}. For simplicity, it focuses on scorers
* that produce regular {@link DocIdSetIterator}s and not {@link TwoPhaseIterator}s.
*/
final class ConjunctionBulkScorer extends BulkScorer {
private final Scorer[] scoringScorers;
private final DocIdSetIterator lead1, lead2;
private final List<DocIdSetIterator> others;
private final Scorable scorable;
ConjunctionBulkScorer(List<Scorer> requiredScoring, List<Scorer> requiredNoScoring)
throws IOException {
final int numClauses = requiredScoring.size() + requiredNoScoring.size();
if (numClauses <= 1) {
throw new IllegalArgumentException("Expected 2 or more clauses, got " + numClauses);
}
List<Scorer> allScorers = new ArrayList<>();
allScorers.addAll(requiredScoring);
allScorers.addAll(requiredNoScoring);
this.scoringScorers = requiredScoring.toArray(Scorer[]::new);
List<DocIdSetIterator> iterators = new ArrayList<>();
for (Scorer scorer : allScorers) {
iterators.add(scorer.iterator());
}
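// Order the iterators from cheapest to most expensive: the two cheapest ones
// become lead1 and lead2 and drive the intersection, while the remaining
// iterators are only asked to confirm candidate doc IDs.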
Collections.sort(iterators, Comparator.comparingLong(DocIdSetIterator::cost));
lead1 = iterators.get(0);
lead2 = iterators.get(1);
others = List.copyOf(iterators.subList(2, iterators.size()));
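// The Scorable handed to collectors sums the scores of the scoring clauses only;
// required non-scoring (filter) clauses never contribute to the score.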
scorable =
new Scorable() {
@Override
public float score() throws IOException {
double score = 0;
for (Scorer scorer : scoringScorers) {
score += scorer.score();
}
return (float) score;
}
@Override
public Collection<ChildScorable> getChildren() throws IOException {
ArrayList<ChildScorable> children = new ArrayList<>();
for (Scorer scorer : allScorers) {
children.add(new ChildScorable(scorer, "MUST"));
}
return children;
}
};
}
@Override
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
assert lead1.docID() >= lead2.docID();
if (lead1.docID() < min) {
lead1.advance(min);
}
if (lead1.docID() >= max) {
return lead1.docID();
}
collector.setScorer(scorable);
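// If the collector exposes a competitive iterator, treat it as one more required
// clause so that docs the collector is no longer interested in get skipped like
// any other non-matching doc.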
List<DocIdSetIterator> otherIterators = this.others;
DocIdSetIterator collectorIterator = collector.competitiveIterator();
if (collectorIterator != null) {
otherIterators = new ArrayList<>(otherIterators);
otherIterators.add(collectorIterator);
}
final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new);
// In the main for loop, we want to be able to rely on the invariant that lead1.docID() >
// lead2.docID(). However it's possible that these two are equal on the first document in a
// scoring window. So we treat this case separately here.
if (lead1.docID() == lead2.docID()) {
final int doc = lead1.docID();
if (acceptDocs == null || acceptDocs.get(doc)) {
boolean match = true;
for (DocIdSetIterator it : others) {
if (it.docID() < doc) {
int next = it.advance(doc);
if (next != doc) {
lead1.advance(next);
match = false;
break;
}
}
assert it.docID() == doc;
}
if (match) {
collector.collect(doc);
lead1.nextDoc();
}
} else {
lead1.nextDoc();
}
}
advanceHead:
for (int doc = lead1.docID(); doc < max; ) {
assert lead2.docID() < doc;
if (acceptDocs != null && acceptDocs.get(doc) == false) {
doc = lead1.nextDoc();
continue;
}
// We maintain the invariant that lead2.docID() < lead1.docID() so that we don't need to check
// if lead2 is already on the same doc as lead1 here.
int next2 = lead2.advance(doc);
if (next2 != doc) {
doc = lead1.advance(next2);
if (doc != next2) {
continue;
} else if (doc >= max) {
break;
} else if (acceptDocs != null && acceptDocs.get(doc) == false) {
doc = lead1.nextDoc();
continue;
}
}
assert lead2.docID() == doc;
for (DocIdSetIterator it : others) {
if (it.docID() < doc) {
int next = it.advance(doc);
if (next != doc) {
doc = lead1.advance(next);
continue advanceHead;
}
}
assert it.docID() == doc;
}
collector.collect(doc);
doc = lead1.nextDoc();
}
return lead1.docID();
}
@Override
public long cost() {
return lead1.cost();
}
}
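
To show where this class sits, here is a simplified, hedged sketch of the per-segment
loop that exhaustive collection runs (it is not code from this change; `searcher` and
`query` are assumed to exist, and imports from org.apache.lucene.index/search are
omitted). For an eligible conjunction, Weight#bulkScorer returns the new
ConjunctionBulkScorer and a single score() call walks the whole segment:

```java
static int countExhaustively(IndexSearcher searcher, Query query) throws IOException {
  Weight weight =
      searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f);
  TotalHitCountCollector collector = new TotalHitCountCollector();
  for (LeafReaderContext leafCtx : searcher.getIndexReader().leaves()) {
    BulkScorer bulkScorer = weight.bulkScorer(leafCtx);
    if (bulkScorer == null) {
      continue; // no possible matches in this segment
    }
    LeafCollector leafCollector = collector.getLeafCollector(leafCtx);
    // Live docs act as acceptDocs; [0, NO_MORE_DOCS) covers the whole doc-ID space.
    bulkScorer.score(
        leafCollector, leafCtx.reader().getLiveDocs(), 0, DocIdSetIterator.NO_MORE_DOCS);
  }
  return collector.getTotalHits();
}
```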

File: lucene/core/src/test/org/apache/lucene/search/TestSubScorerFreqs.java

@@ -36,6 +36,7 @@ import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.AssertingScorable;
import org.apache.lucene.tests.search.DisablingBulkScorerQuery;
import org.apache.lucene.tests.util.LuceneTestCase;
import org.junit.AfterClass;
import org.junit.BeforeClass;
@@ -203,7 +204,8 @@ public class TestSubScorerFreqs extends LuceneTestCase {
for (final Set<String> occur : occurList) {
Map<Integer, Map<Query, Float>> docCounts =
s.search(query.build(), new CountingCollectorManager(occur));
s.search(
new DisablingBulkScorerQuery(query.build()), new CountingCollectorManager(occur));
final int maxDocs = s.getIndexReader().maxDoc();
assertEquals(maxDocs, docCounts.size());
boolean includeOptional = occur.contains("SHOULD");

File: lucene/test-framework/src/java/org/apache/lucene/tests/search/DisablingBulkScorerQuery.java (new file)

@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.tests.search;
import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.FilterWeight;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
/** A {@link Query} wrapper that disables bulk-scoring optimizations. */
public class DisablingBulkScorerQuery extends Query {
private final Query query;
/** Sole constructor. */
public DisablingBulkScorerQuery(Query query) {
this.query = query;
}
@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
Query rewritten = query.rewrite(indexSearcher);
if (query != rewritten) {
return new DisablingBulkScorerQuery(rewritten);
}
return super.rewrite(indexSearcher);
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
Weight in = query.createWeight(searcher, scoreMode, boost);
return new FilterWeight(in) {
@Override
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
Scorer scorer = scorer(context);
if (scorer == null) {
return null;
}
return new DefaultBulkScorer(scorer);
}
};
}
@Override
public String toString(String field) {
return query.toString(field);
}
@Override
public void visit(QueryVisitor visitor) {
query.visit(visitor);
}
@Override
public boolean equals(Object obj) {
return sameClassAs(obj) && query.equals(((DisablingBulkScorerQuery) obj).query);
}
@Override
public int hashCode() {
return 31 * classHash() + query.hashCode();
}
}
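
A hedged usage sketch of the new test utility (the query contents and `searcher` are
illustrative): wrapping a query leaves matching and scoring untouched and only forces
the default doc-at-a-time bulk scorer, which is presumably why TestSubScorerFreqs
above wraps its query before searching.

```java
// Illustrative conjunction; any query can be wrapped.
Query original =
    new BooleanQuery.Builder()
        .add(new TermQuery(new Term("field", "a")), BooleanClause.Occur.MUST)
        .add(new TermQuery(new Term("field", "b")), BooleanClause.Occur.MUST)
        .build();
// Same matches and scores as `original`, but Weight#bulkScorer now always returns
// a DefaultBulkScorer built around scorer(), bypassing ConjunctionBulkScorer.
Query wrapped = new DisablingBulkScorerQuery(original);
TopDocs hits = searcher.search(wrapped, 10);
```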