mirror of https://github.com/apache/lucene.git
Add a specialized bulk scorer for regular conjunctions. (#12719)
PR #12382 added a bulk scorer for top-k hits on conjunctions that yielded a significant speedup (annotation [FP](http://people.apache.org/~mikemccand/lucenebench/AndHighHigh.html)). This change proposes a similar change for exhaustive collection of conjunctive queries, e.g. for counting, faceting, etc.
This commit is contained in:
parent
2ed60e8073
commit
58b9352cc2
|
@ -244,6 +244,9 @@ Optimizations
|
|||
|
||||
* GITHUB#12726: Return the same input vector if it's a unit vector in VectorUtil#l2normalize. (Shubham Chaudhary)
|
||||
|
||||
* GITHUB#12719: Top-level conjunctions that are not sorted by score now have a
|
||||
specialized bulk scorer. (Adrien Grand)
|
||||
|
||||
Changes in runtime behavior
|
||||
---------------------
|
||||
|
||||
|
|
|
@ -317,6 +317,12 @@ final class BooleanWeight extends Weight {
|
|||
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
|
||||
return new BlockMaxConjunctionBulkScorer(context.reader().maxDoc(), requiredScoring);
|
||||
}
|
||||
if (scoreMode != ScoreMode.TOP_SCORES
|
||||
&& requiredScoring.size() + requiredNoScoring.size() >= 2
|
||||
&& requiredScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)
|
||||
&& requiredNoScoring.stream().map(Scorer::twoPhaseIterator).allMatch(Objects::isNull)) {
|
||||
return new ConjunctionBulkScorer(requiredScoring, requiredNoScoring);
|
||||
}
|
||||
if (scoreMode == ScoreMode.TOP_SCORES && requiredScoring.size() > 1) {
|
||||
requiredScoring =
|
||||
Collections.singletonList(new BlockMaxConjunctionScorer(this, requiredScoring));
|
||||
|
|
|
@ -0,0 +1,177 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import org.apache.lucene.util.Bits;
|
||||
|
||||
/**
|
||||
* BulkScorer implementation of {@link ConjunctionScorer}. For simplicity, it focuses on scorers
|
||||
* that produce regular {@link DocIdSetIterator}s and not {@link TwoPhaseIterator}s.
|
||||
*/
|
||||
final class ConjunctionBulkScorer extends BulkScorer {
|
||||
|
||||
private final Scorer[] scoringScorers;
|
||||
private final DocIdSetIterator lead1, lead2;
|
||||
private final List<DocIdSetIterator> others;
|
||||
private final Scorable scorable;
|
||||
|
||||
ConjunctionBulkScorer(List<Scorer> requiredScoring, List<Scorer> requiredNoScoring)
|
||||
throws IOException {
|
||||
final int numClauses = requiredScoring.size() + requiredNoScoring.size();
|
||||
if (numClauses <= 1) {
|
||||
throw new IllegalArgumentException("Expected 2 or more clauses, got " + numClauses);
|
||||
}
|
||||
List<Scorer> allScorers = new ArrayList<>();
|
||||
allScorers.addAll(requiredScoring);
|
||||
allScorers.addAll(requiredNoScoring);
|
||||
|
||||
this.scoringScorers = requiredScoring.toArray(Scorer[]::new);
|
||||
List<DocIdSetIterator> iterators = new ArrayList<>();
|
||||
for (Scorer scorer : allScorers) {
|
||||
iterators.add(scorer.iterator());
|
||||
}
|
||||
Collections.sort(iterators, Comparator.comparingLong(DocIdSetIterator::cost));
|
||||
lead1 = iterators.get(0);
|
||||
lead2 = iterators.get(1);
|
||||
others = List.copyOf(iterators.subList(2, iterators.size()));
|
||||
scorable =
|
||||
new Scorable() {
|
||||
@Override
|
||||
public float score() throws IOException {
|
||||
double score = 0;
|
||||
for (Scorer scorer : scoringScorers) {
|
||||
score += scorer.score();
|
||||
}
|
||||
return (float) score;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Collection<ChildScorable> getChildren() throws IOException {
|
||||
ArrayList<ChildScorable> children = new ArrayList<>();
|
||||
for (Scorer scorer : allScorers) {
|
||||
children.add(new ChildScorable(scorer, "MUST"));
|
||||
}
|
||||
return children;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int score(LeafCollector collector, Bits acceptDocs, int min, int max) throws IOException {
|
||||
assert lead1.docID() >= lead2.docID();
|
||||
|
||||
if (lead1.docID() < min) {
|
||||
lead1.advance(min);
|
||||
}
|
||||
|
||||
if (lead1.docID() >= max) {
|
||||
return lead1.docID();
|
||||
}
|
||||
|
||||
collector.setScorer(scorable);
|
||||
|
||||
List<DocIdSetIterator> otherIterators = this.others;
|
||||
DocIdSetIterator collectorIterator = collector.competitiveIterator();
|
||||
if (collectorIterator != null) {
|
||||
otherIterators = new ArrayList<>(otherIterators);
|
||||
otherIterators.add(collectorIterator);
|
||||
}
|
||||
|
||||
final DocIdSetIterator[] others = otherIterators.toArray(DocIdSetIterator[]::new);
|
||||
|
||||
// In the main for loop, we want to be able to rely on the invariant that lead1.docID() >
|
||||
// lead2.doc(). However it's possible that these two are equal on the first document in a
|
||||
// scoring window. So we treat this case separately here.
|
||||
if (lead1.docID() == lead2.docID()) {
|
||||
final int doc = lead1.docID();
|
||||
if (acceptDocs == null || acceptDocs.get(doc)) {
|
||||
boolean match = true;
|
||||
for (DocIdSetIterator it : others) {
|
||||
if (it.docID() < doc) {
|
||||
int next = it.advance(doc);
|
||||
if (next != doc) {
|
||||
lead1.advance(next);
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert it.docID() == doc;
|
||||
}
|
||||
|
||||
if (match) {
|
||||
collector.collect(doc);
|
||||
lead1.nextDoc();
|
||||
}
|
||||
} else {
|
||||
lead1.nextDoc();
|
||||
}
|
||||
}
|
||||
|
||||
advanceHead:
|
||||
for (int doc = lead1.docID(); doc < max; ) {
|
||||
assert lead2.docID() < doc;
|
||||
|
||||
if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||
doc = lead1.nextDoc();
|
||||
continue;
|
||||
}
|
||||
|
||||
// We maintain the invariant that lead2.docID() < lead1.docID() so that we don't need to check
|
||||
// if lead2 is already on the same doc as lead1 here.
|
||||
int next2 = lead2.advance(doc);
|
||||
if (next2 != doc) {
|
||||
doc = lead1.advance(next2);
|
||||
if (doc != next2) {
|
||||
continue;
|
||||
} else if (doc >= max) {
|
||||
break;
|
||||
} else if (acceptDocs != null && acceptDocs.get(doc) == false) {
|
||||
doc = lead1.nextDoc();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
assert lead2.docID() == doc;
|
||||
|
||||
for (DocIdSetIterator it : others) {
|
||||
if (it.docID() < doc) {
|
||||
int next = it.advance(doc);
|
||||
if (next != doc) {
|
||||
doc = lead1.advance(next);
|
||||
continue advanceHead;
|
||||
}
|
||||
}
|
||||
assert it.docID() == doc;
|
||||
}
|
||||
|
||||
collector.collect(doc);
|
||||
doc = lead1.nextDoc();
|
||||
}
|
||||
|
||||
return lead1.docID();
|
||||
}
|
||||
|
||||
@Override
|
||||
public long cost() {
|
||||
return lead1.cost();
|
||||
}
|
||||
}
|
|
@ -36,6 +36,7 @@ import org.apache.lucene.store.Directory;
|
|||
import org.apache.lucene.tests.analysis.MockAnalyzer;
|
||||
import org.apache.lucene.tests.index.RandomIndexWriter;
|
||||
import org.apache.lucene.tests.search.AssertingScorable;
|
||||
import org.apache.lucene.tests.search.DisablingBulkScorerQuery;
|
||||
import org.apache.lucene.tests.util.LuceneTestCase;
|
||||
import org.junit.AfterClass;
|
||||
import org.junit.BeforeClass;
|
||||
|
@ -203,7 +204,8 @@ public class TestSubScorerFreqs extends LuceneTestCase {
|
|||
|
||||
for (final Set<String> occur : occurList) {
|
||||
Map<Integer, Map<Query, Float>> docCounts =
|
||||
s.search(query.build(), new CountingCollectorManager(occur));
|
||||
s.search(
|
||||
new DisablingBulkScorerQuery(query.build()), new CountingCollectorManager(occur));
|
||||
final int maxDocs = s.getIndexReader().maxDoc();
|
||||
assertEquals(maxDocs, docCounts.size());
|
||||
boolean includeOptional = occur.contains("SHOULD");
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.lucene.tests.search;
|
||||
|
||||
import java.io.IOException;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.search.BulkScorer;
|
||||
import org.apache.lucene.search.FilterWeight;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.QueryVisitor;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.Scorer;
|
||||
import org.apache.lucene.search.Weight;
|
||||
|
||||
/** A {@link Query} wrapper that disables bulk-scoring optimizations. */
|
||||
public class DisablingBulkScorerQuery extends Query {
|
||||
|
||||
private final Query query;
|
||||
|
||||
/** Sole constructor. */
|
||||
public DisablingBulkScorerQuery(Query query) {
|
||||
this.query = query;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
|
||||
Query rewritten = query.rewrite(indexSearcher);
|
||||
if (query != rewritten) {
|
||||
return new DisablingBulkScorerQuery(rewritten);
|
||||
}
|
||||
return super.rewrite(indexSearcher);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
|
||||
throws IOException {
|
||||
Weight in = query.createWeight(searcher, scoreMode, boost);
|
||||
return new FilterWeight(in) {
|
||||
@Override
|
||||
public BulkScorer bulkScorer(LeafReaderContext context) throws IOException {
|
||||
Scorer scorer = scorer(context);
|
||||
if (scorer == null) {
|
||||
return null;
|
||||
}
|
||||
return new DefaultBulkScorer(scorer);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString(String field) {
|
||||
return query.toString(field);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visit(QueryVisitor visitor) {
|
||||
query.visit(visitor);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object obj) {
|
||||
return sameClassAs(obj) && query.equals(((DisablingBulkScorerQuery) obj).query);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * classHash() + query.hashCode();
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue