Ensure #finish is called on all drill-sideways FacetCollectors even when no hits are scored (#12853)

This commit is contained in:
Greg Miller 2023-12-08 15:25:57 -08:00 committed by GitHub
parent fb269c9e64
commit a9b5ef4749
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 177 additions and 77 deletions

View File

@ -186,9 +186,11 @@ Optimizations
Bug Fixes
---------------------
* GITHUB#12866: Prevent extra similarity computation for single-level HNSW graphs. (Kaival Parikh)
* GITHUB#12558: Ensure #finish is called on all drill-sideways FacetsCollectors even when no hits are scored.
(Greg Miller)
Other
---------------------

View File

@ -137,6 +137,14 @@ public class DrillSideways {
return new FacetsCollectorManager();
}
/**
* Subclass can override to customize drill sideways facets collector. This should not return
* {@code null} as we assume drill sideways is being used to collect "sideways" hits:
*/
protected FacetsCollectorManager createDrillSidewaysFacetsCollectorManager() {
return new FacetsCollectorManager();
}
/** Subclass can override to customize per-dim Facets impl. */
protected Facets buildFacetsResult(
FacetsCollector drillDowns, FacetsCollector[] drillSideways, String[] drillSidewaysDims)
@ -397,7 +405,7 @@ public class DrillSideways {
FacetsCollectorManager[] drillSidewaysFacetsCollectorManagers =
new FacetsCollectorManager[numDims];
for (int i = 0; i < numDims; i++) {
drillSidewaysFacetsCollectorManagers[i] = new FacetsCollectorManager();
drillSidewaysFacetsCollectorManagers[i] = createDrillSidewaysFacetsCollectorManager();
}
DrillSidewaysQuery dsq =
@ -467,7 +475,10 @@ public class DrillSideways {
for (String dim : drillDownDims.keySet())
callableCollectors.add(
new CallableCollector(
i++, searcher, getDrillDownQuery(query, filters, dim), new FacetsCollectorManager()));
i++,
searcher,
getDrillDownQuery(query, filters, dim),
createDrillSidewaysFacetsCollectorManager()));
final FacetsCollector mainFacetsCollector;
final FacetsCollector[] facetsCollectors = new FacetsCollector[drillDownDims.size()];

View File

@ -23,13 +23,13 @@ import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Objects;
import org.apache.lucene.facet.DrillSidewaysScorer.DocsAndCost;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
@ -175,6 +175,17 @@ class DrillSidewaysQuery extends Query {
int drillDownCount = drillDowns.length;
FacetsCollector drillDownCollector;
LeafCollector drillDownLeafCollector;
if (drillDownCollectorManager != null) {
drillDownCollector = drillDownCollectorManager.newCollector();
managedDrillDownCollectors.add(drillDownCollector);
drillDownLeafCollector = drillDownCollector.getLeafCollector(context);
} else {
drillDownCollector = null;
drillDownLeafCollector = null;
}
FacetsCollector[] sidewaysCollectors = new FacetsCollector[drillDownCount];
managedDrillSidewaysCollectors.add(sidewaysCollectors);
@ -193,42 +204,29 @@ class DrillSidewaysQuery extends Query {
FacetsCollector sidewaysCollector = drillSidewaysCollectorManagers[dim].newCollector();
sidewaysCollectors[dim] = sidewaysCollector;
dims[dim] = new DrillSidewaysScorer.DocsAndCost(scorer, sidewaysCollector);
dims[dim] =
new DrillSidewaysScorer.DocsAndCost(
scorer, sidewaysCollector.getLeafCollector(context));
}
// If more than one dim has no matches, then there
// are no hits nor drill-sideways counts. Or, if we
// have only one dim and that dim has no matches,
// same thing.
// if (nullCount > 1 || (nullCount == 1 && dims.length == 1)) {
if (nullCount > 1) {
// If baseScorer is null or the dim nullCount > 1, then we have nothing to score. We return
// a null scorer in this case, but we need to make sure #finish gets called on all facet
// collectors since IndexSearcher won't handle this for us:
if (baseScorer == null || nullCount > 1) {
if (drillDownCollector != null) {
drillDownCollector.finish();
}
for (FacetsCollector fc : sidewaysCollectors) {
fc.finish();
}
return null;
}
// Sort drill-downs by most restrictive first:
Arrays.sort(
dims,
new Comparator<DrillSidewaysScorer.DocsAndCost>() {
@Override
public int compare(DocsAndCost o1, DocsAndCost o2) {
return Long.compare(o1.approximation.cost(), o2.approximation.cost());
}
});
if (baseScorer == null) {
return null;
}
FacetsCollector drillDownCollector;
if (drillDownCollectorManager != null) {
drillDownCollector = drillDownCollectorManager.newCollector();
managedDrillDownCollectors.add(drillDownCollector);
} else {
drillDownCollector = null;
}
Arrays.sort(dims, Comparator.comparingLong(o -> o.approximation.cost()));
return new DrillSidewaysScorer(
context, baseScorer, drillDownCollector, dims, scoreSubDocsAtOnce);
context, baseScorer, drillDownLeafCollector, dims, scoreSubDocsAtOnce);
}
};
}

View File

@ -24,7 +24,6 @@ import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.BulkScorer;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Scorable;
@ -46,8 +45,7 @@ class DrillSidewaysScorer extends BulkScorer {
// private static boolean DEBUG = false;
private final Collector drillDownCollector;
private LeafCollector drillDownLeafCollector;
private final LeafCollector drillDownLeafCollector;
private final DocsAndCost[] dims;
@ -70,7 +68,7 @@ class DrillSidewaysScorer extends BulkScorer {
DrillSidewaysScorer(
LeafReaderContext context,
Scorer baseScorer,
Collector drillDownCollector,
LeafCollector drillDownLeafCollector,
DocsAndCost[] dims,
boolean scoreSubDocsAtOnce) {
this.dims = dims;
@ -83,7 +81,7 @@ class DrillSidewaysScorer extends BulkScorer {
} else {
this.baseApproximation = baseIterator;
}
this.drillDownCollector = drillDownCollector;
this.drillDownLeafCollector = drillDownLeafCollector;
this.scoreSubDocsAtOnce = scoreSubDocsAtOnce;
}
@ -101,18 +99,6 @@ class DrillSidewaysScorer extends BulkScorer {
if (maxDoc != Integer.MAX_VALUE) {
throw new IllegalArgumentException("maxDoc must be Integer.MAX_VALUE");
}
// if (DEBUG) {
// System.out.println("\nscore: reader=" + context.reader());
// }
// System.out.println("score r=" + context.reader());
if (drillDownCollector != null) {
drillDownLeafCollector = drillDownCollector.getLeafCollector(context);
} else {
drillDownLeafCollector = null;
}
for (DocsAndCost dim : dims) {
dim.sidewaysLeafCollector = dim.sidewaysCollector.getLeafCollector(context);
}
// some scorers, eg ReqExlScorer, can hit NPE if cost is called after nextDoc
long baseQueryCost = baseIterator.cost();
@ -723,7 +709,7 @@ class DrillSidewaysScorer extends BulkScorer {
// }
collector.collect(collectDocID);
if (drillDownCollector != null) {
if (drillDownLeafCollector != null) {
drillDownLeafCollector.collect(collectDocID);
}
@ -739,7 +725,7 @@ class DrillSidewaysScorer extends BulkScorer {
private void collectHit(LeafCollector collector, DocsAndCost dim) throws IOException {
collector.collect(collectDocID);
if (drillDownCollector != null) {
if (drillDownLeafCollector != null) {
drillDownLeafCollector.collect(collectDocID);
}
@ -749,7 +735,7 @@ class DrillSidewaysScorer extends BulkScorer {
private void collectHit(LeafCollector collector, List<DocsAndCost> dims) throws IOException {
collector.collect(collectDocID);
if (drillDownCollector != null) {
if (drillDownLeafCollector != null) {
drillDownLeafCollector.collect(collectDocID);
}
@ -808,10 +794,9 @@ class DrillSidewaysScorer extends BulkScorer {
// two-phase confirmation, or null if the approximation is accurate
final TwoPhaseIterator twoPhase;
final float matchCost;
final Collector sidewaysCollector;
LeafCollector sidewaysLeafCollector;
final LeafCollector sidewaysLeafCollector;
DocsAndCost(Scorer scorer, Collector sidewaysCollector) {
DocsAndCost(Scorer scorer, LeafCollector sidewaysLeafCollector) {
final TwoPhaseIterator twoPhase = scorer.twoPhaseIterator();
if (twoPhase == null) {
this.approximation = scorer.iterator();
@ -823,7 +808,7 @@ class DrillSidewaysScorer extends BulkScorer {
this.matchCost = twoPhase.matchCost();
}
this.cost = approximation.cost();
this.sidewaysCollector = sidewaysCollector;
this.sidewaysLeafCollector = sidewaysLeafCollector;
}
}
}

View File

@ -143,8 +143,14 @@ public class FacetsCollector extends SimpleCollector {
@Override
public void finish() throws IOException {
matchingDocs.add(new MatchingDocs(this.context, docsBuilder.build(), totalHits, scores));
DocIdSet bits;
if (docsBuilder != null) {
bits = docsBuilder.build();
docsBuilder = null;
} else {
bits = DocIdSet.EMPTY;
}
matchingDocs.add(new MatchingDocs(this.context, bits, totalHits, scores));
scores = null;
context = null;
}

View File

@ -64,6 +64,7 @@ import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCachingPolicy;
@ -289,7 +290,7 @@ public class TestDrillSideways extends FacetTestCase {
FacetsCollector baseFC = new FacetsCollector();
FacetsCollector dimFC = new FacetsCollector();
DrillSidewaysScorer.DocsAndCost docsAndCost =
new DrillSidewaysScorer.DocsAndCost(dimScorer, dimFC);
new DrillSidewaysScorer.DocsAndCost(dimScorer, dimFC.getLeafCollector(ctx));
LeafCollector baseCollector =
new LeafCollector() {
@ -313,7 +314,7 @@ public class TestDrillSideways extends FacetTestCase {
new DrillSidewaysScorer(
ctx,
baseScorer,
baseFC,
baseFC.getLeafCollector(ctx),
new DrillSidewaysScorer.DocsAndCost[] {docsAndCost},
scoreSubDocsAtOnce);
expectThrows(CollectionTerminatedException.class, () -> scorer.score(baseCollector, null));
@ -407,23 +408,7 @@ public class TestDrillSideways extends FacetTestCase {
try (IndexReader r = w.getReader();
TaxonomyReader taxoR = new DirectoryTaxonomyReader(taxoW)) {
// We can't use AssertingIndexSearcher unfortunately since it may randomly decide to bulk
// score a sub-range of docs instead of all docs at once. This is incompatible will drill
// sideways, so we have to do our own check here. This just makes sure we call #finish on
// the last leaf. It's too bad we need to do this and maybe some day we can clean this up
// by rethinking drill-sideways:
IndexSearcher searcher =
new IndexSearcher(r) {
@Override
protected void search(
List<LeafReaderContext> leaves, Weight weight, Collector collector)
throws IOException {
AssertingCollector assertingCollector = AssertingCollector.wrap(collector);
super.search(leaves, weight, assertingCollector);
assert assertingCollector.hasFinishedCollectingPreviousLeaf;
}
};
IndexSearcher searcher = new DrillSidewaysAssertingIndexSearcher(r);
Query baseQuery = new MatchAllDocsQuery();
DrillDownQuery ddq = new DrillDownQuery(facetsConfig, baseQuery);
@ -1469,6 +1454,66 @@ public class TestDrillSideways extends FacetTestCase {
IOUtils.close(r, tr, tw, d, td);
}
public void testFinishOnAllDimsNoHitsQuery() throws Exception {
try (Directory dir = newDirectory();
Directory taxoDir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
DirectoryTaxonomyWriter taxoW =
new DirectoryTaxonomyWriter(taxoDir, IndexWriterConfig.OpenMode.CREATE)) {
FacetsConfig facetsConfig = new FacetsConfig();
Document d = new Document();
d.add(new FacetField("foo", "bar"));
w.addDocument(facetsConfig.build(taxoW, d));
try (IndexReader r = w.getReader();
TaxonomyReader taxoR = new DirectoryTaxonomyReader(taxoW)) {
IndexSearcher searcher = new DrillSidewaysAssertingIndexSearcher(r);
// Creating a query that matches nothing to make sure #finish still gets called on all
// facet collectors:
Query baseQuery = new MatchNoDocsQuery();
DrillDownQuery ddq = new DrillDownQuery(facetsConfig, baseQuery);
ddq.add("foo", "bar");
DrillSideways drillSideways =
new DrillSideways(searcher, facetsConfig, taxoR) {
@Override
protected FacetsCollectorManager createDrillDownFacetsCollectorManager() {
return new AssertingFacetsCollectorManager();
}
@Override
protected FacetsCollectorManager createDrillSidewaysFacetsCollectorManager() {
return new AssertingFacetsCollectorManager();
}
};
SimpleCollectorManager cm =
new SimpleCollectorManager(10, Comparator.comparingInt(a -> a.docAndScore.doc));
DrillSideways.ConcurrentDrillSidewaysResult<List<DocAndScore>> result =
drillSideways.search(ddq, cm);
assertEquals(0, result.collectorResult.size());
// Make sure the "matching docs" are still populated with the appropriate leaf reader
// context, which happens as part of #finish getting called:
assertEquals(1, result.drillDownFacetsCollector.getMatchingDocs().size());
assertEquals(
1, result.drillDownFacetsCollector.getMatchingDocs().get(0).context.reader().maxDoc());
assertEquals(1, result.drillSidewaysFacetsCollector.length);
assertEquals(1, result.drillSidewaysFacetsCollector[0].getMatchingDocs().size());
assertEquals(
1,
result
.drillSidewaysFacetsCollector[0]
.getMatchingDocs()
.get(0)
.context
.reader()
.maxDoc());
}
}
}
private static class Counters {
int[][] counts;
@ -1620,6 +1665,59 @@ public class TestDrillSideways extends FacetTestCase {
}
}
// We can't use AssertingIndexSearcher unfortunately since it may randomly decide to bulk
// score a sub-range of docs instead of all docs at once. This is incompatible will drill
// sideways, so we have to do our own check here. This just makes sure we call #finish on
// the last leaf. It's too bad we need to do this and maybe some day we can clean this up
// by rethinking drill-sideways:
private static final class DrillSidewaysAssertingIndexSearcher extends IndexSearcher {
DrillSidewaysAssertingIndexSearcher(IndexReader r) {
super(r);
}
@Override
protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector)
throws IOException {
AssertingCollector assertingCollector = AssertingCollector.wrap(collector);
super.search(leaves, weight, assertingCollector);
assert assertingCollector.hasFinishedCollectingPreviousLeaf;
}
}
private static final class AssertingFacetsCollectorManager extends FacetsCollectorManager {
@Override
public FacetsCollector newCollector() throws IOException {
return new AssertingFacetsCollector();
}
@Override
public FacetsCollector reduce(Collection<FacetsCollector> collectors) throws IOException {
for (FacetsCollector fc : collectors) {
assert fc instanceof AssertingFacetsCollector;
assert ((AssertingFacetsCollector) fc).hasFinishedLastLeaf == true;
}
return super.reduce(collectors);
}
}
private static final class AssertingFacetsCollector extends FacetsCollector {
private boolean hasFinishedLastLeaf = true;
@Override
protected void doSetNextReader(LeafReaderContext context) throws IOException {
assert hasFinishedLastLeaf;
hasFinishedLastLeaf = false;
super.doSetNextReader(context);
}
@Override
public void finish() throws IOException {
hasFinishedLastLeaf = true;
super.finish();
}
}
private int[] getTopNOrds(final int[] counts, final String[] values, int topN) {
final int[] ids = new int[counts.length];
for (int i = 0; i < ids.length; i++) {