Fix bug in sort optimization (#1903)

Fix bug how iterator with skipping functionality
advances and produces docs

Relates to #1725
This commit is contained in:
Mayya Sharipova 2020-09-23 09:09:43 -04:00 committed by GitHub
parent e19239d96b
commit 7d90b858c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 79 additions and 35 deletions

View File

@ -207,13 +207,13 @@ public abstract class Weight implements SegmentCacheable {
// if possible filter scorerIterator to keep only competitive docs as defined by collector // if possible filter scorerIterator to keep only competitive docs as defined by collector
DocIdSetIterator filteredIterator = collectorIterator == null ? scorerIterator : DocIdSetIterator filteredIterator = collectorIterator == null ? scorerIterator :
ConjunctionDISI.intersectIterators(Arrays.asList(scorerIterator, collectorIterator)); ConjunctionDISI.intersectIterators(Arrays.asList(scorerIterator, collectorIterator));
if (scorer.docID() == -1 && min == 0 && max == DocIdSetIterator.NO_MORE_DOCS) { if (filteredIterator.docID() == -1 && min == 0 && max == DocIdSetIterator.NO_MORE_DOCS) {
scoreAll(collector, filteredIterator, twoPhase, acceptDocs); scoreAll(collector, filteredIterator, twoPhase, acceptDocs);
return DocIdSetIterator.NO_MORE_DOCS; return DocIdSetIterator.NO_MORE_DOCS;
} else { } else {
int doc = scorer.docID(); int doc = filteredIterator.docID();
if (doc < min) { if (doc < min) {
doc = scorerIterator.advance(min); doc = filteredIterator.advance(min);
} }
return scoreRange(collector, filteredIterator, twoPhase, acceptDocs, doc, max); return scoreRange(collector, filteredIterator, twoPhase, acceptDocs, doc, max);
} }

View File

@ -133,16 +133,14 @@ public class DocComparator extends FieldComparator<Integer> {
return null; return null;
} else { } else {
return new DocIdSetIterator() { return new DocIdSetIterator() {
private int doc;
@Override @Override
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
return doc = competitiveIterator.nextDoc(); return competitiveIterator.nextDoc();
} }
@Override @Override
public int docID() { public int docID() {
return doc; return competitiveIterator.docID();
} }
@Override @Override
@ -152,7 +150,7 @@ public class DocComparator extends FieldComparator<Integer> {
@Override @Override
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
return doc = competitiveIterator.advance(target); return competitiveIterator.advance(target);
} }
}; };
} }
@ -176,7 +174,7 @@ public class DocComparator extends FieldComparator<Integer> {
if (docBase + maxDoc <= minDoc) { if (docBase + maxDoc <= minDoc) {
competitiveIterator = DocIdSetIterator.empty(); // skip this segment competitiveIterator = DocIdSetIterator.empty(); // skip this segment
} else { } else {
int segmentMinDoc = Math.max(0, minDoc - docBase); int segmentMinDoc = Math.max(competitiveIterator.docID(), minDoc - docBase);
competitiveIterator = new MinDocIterator(segmentMinDoc, maxDoc); competitiveIterator = new MinDocIterator(segmentMinDoc, maxDoc);
} }
} }

View File

@ -220,16 +220,14 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
public DocIdSetIterator competitiveIterator() { public DocIdSetIterator competitiveIterator() {
if (enableSkipping == false) return null; if (enableSkipping == false) return null;
return new DocIdSetIterator() { return new DocIdSetIterator() {
private int doc;
@Override @Override
public int nextDoc() throws IOException { public int nextDoc() throws IOException {
return doc = competitiveIterator.nextDoc(); return competitiveIterator.nextDoc();
} }
@Override @Override
public int docID() { public int docID() {
return doc; return competitiveIterator.docID();
} }
@Override @Override
@ -239,7 +237,7 @@ public abstract class NumericComparator<T extends Number> extends FieldComparato
@Override @Override
public int advance(int target) throws IOException { public int advance(int target) throws IOException {
return doc = competitiveIterator.advance(target); return competitiveIterator.advance(target);
} }
}; };
} }

View File

@ -53,6 +53,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
if (i == 7000) writer.flush(); // two segments if (i == 7000) writer.flush(); // two segments
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
writer.close();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
final SortField sortField = new SortField("my_field", SortField.Type.LONG); final SortField sortField = new SortField("my_field", SortField.Type.LONG);
final Sort sort = new Sort(sortField); final Sort sort = new Sort(sortField);
@ -110,7 +111,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run
} }
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -131,6 +131,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
writer.addDocument(doc); writer.addDocument(doc);
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
writer.close();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
final SortField sortField = new SortField("my_field", SortField.Type.LONG); final SortField sortField = new SortField("my_field", SortField.Type.LONG);
final Sort sort = new Sort(sortField); final Sort sort = new Sort(sortField);
@ -147,7 +148,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
} }
assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -167,6 +167,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
if (i == 7000) writer.flush(); // two segments if (i == 7000) writer.flush(); // two segments
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
writer.close();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
final int numHits = 3; final int numHits = 3;
final int totalHitsThreshold = 3; final int totalHitsThreshold = 3;
@ -192,7 +193,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
assertTrue(topDocs.totalHits.value < numDocs); // assert that some docs were skipped => optimization was run assertTrue(topDocs.totalHits.value < numDocs); // assert that some docs were skipped => optimization was run
} }
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -210,6 +210,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
if (i == 7000) writer.flush(); // two segments if (i == 7000) writer.flush(); // two segments
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
writer.close();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
final int numHits = 3; final int numHits = 3;
final int totalHitsThreshold = 3; final int totalHitsThreshold = 3;
@ -261,7 +262,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run assertEquals(topDocs.totalHits.value, numDocs); // assert that all documents were collected => optimization was not run
} }
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -279,6 +279,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
writer.addDocument(doc); writer.addDocument(doc);
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
writer.close();
IndexSearcher searcher = new IndexSearcher(reader); IndexSearcher searcher = new IndexSearcher(reader);
final SortField sortField = new SortField("my_field", SortField.Type.FLOAT); final SortField sortField = new SortField("my_field", SortField.Type.FLOAT);
final Sort sort = new Sort(sortField); final Sort sort = new Sort(sortField);
@ -298,7 +299,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
assertTrue(topDocs.totalHits.value < numDocs); assertTrue(topDocs.totalHits.value < numDocs);
} }
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -311,14 +311,16 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
final Document doc = new Document(); final Document doc = new Document();
writer.addDocument(doc); writer.addDocument(doc);
if ((i > 0) && (i % 50 == 0)) { if ((i > 0) && (i % 50 == 0)) {
writer.commit(); writer.flush();
} }
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
IndexSearcher searcher = new IndexSearcher(reader); writer.close();
final int numHits = 3; IndexSearcher searcher = newSearcher(reader);
final int totalHitsThreshold = 3; final int numHits = 10;
final int[] searchAfters = {10, 140, numDocs - 4}; final int totalHitsThreshold = 10;
final int[] searchAfters = {3, 10, numDocs - 10};
for (int searchAfter : searchAfters) { for (int searchAfter : searchAfters) {
// sort by _doc with search after should trigger optimization // sort by _doc with search after should trigger optimization
{ {
@ -327,14 +329,15 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold); final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold);
searcher.search(new MatchAllDocsQuery(), collector); searcher.search(new MatchAllDocsQuery(), collector);
TopDocs topDocs = collector.topDocs(); TopDocs topDocs = collector.topDocs();
assertEquals(numHits, topDocs.scoreDocs.length); int expNumHits = (searchAfter >= (numDocs - numHits)) ? (numDocs - searchAfter - 1) : numHits;
for (int i = 0; i < numHits; i++) { assertEquals(expNumHits, topDocs.scoreDocs.length);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
int expectedDocID = searchAfter + 1 + i; int expectedDocID = searchAfter + 1 + i;
assertEquals(expectedDocID, topDocs.scoreDocs[i].doc); assertEquals(expectedDocID, topDocs.scoreDocs[i].doc);
} }
assertTrue(collector.isEarlyTerminated()); assertTrue(collector.isEarlyTerminated());
// check that very few docs were collected // check that very few docs were collected
assertTrue(topDocs.totalHits.value < 10); assertTrue(topDocs.totalHits.value < numDocs);
} }
// sort by _doc + _score with search after should trigger optimization // sort by _doc + _score with search after should trigger optimization
@ -344,14 +347,15 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold); final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold);
searcher.search(new MatchAllDocsQuery(), collector); searcher.search(new MatchAllDocsQuery(), collector);
TopDocs topDocs = collector.topDocs(); TopDocs topDocs = collector.topDocs();
assertEquals(numHits, topDocs.scoreDocs.length); int expNumHits = (searchAfter >= (numDocs - numHits)) ? (numDocs - searchAfter - 1) : numHits;
for (int i = 0; i < numHits; i++) { assertEquals(expNumHits, topDocs.scoreDocs.length);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
int expectedDocID = searchAfter + 1 + i; int expectedDocID = searchAfter + 1 + i;
assertEquals(expectedDocID, topDocs.scoreDocs[i].doc); assertEquals(expectedDocID, topDocs.scoreDocs[i].doc);
} }
assertTrue(collector.isEarlyTerminated()); assertTrue(collector.isEarlyTerminated());
// assert that very few docs were collected // assert that very few docs were collected
assertTrue(topDocs.totalHits.value < 10); assertTrue(topDocs.totalHits.value < numDocs);
} }
// sort by _doc desc should not trigger optimization // sort by _doc desc should not trigger optimization
@ -361,8 +365,9 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold); final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, after, totalHitsThreshold);
searcher.search(new MatchAllDocsQuery(), collector); searcher.search(new MatchAllDocsQuery(), collector);
TopDocs topDocs = collector.topDocs(); TopDocs topDocs = collector.topDocs();
assertEquals(numHits, topDocs.scoreDocs.length); int expNumHits = (searchAfter < numHits) ? searchAfter : numHits;
for (int i = 0; i < numHits; i++) { assertEquals(expNumHits, topDocs.scoreDocs.length);
for (int i = 0; i < topDocs.scoreDocs.length; i++) {
int expectedDocID = searchAfter - 1 - i; int expectedDocID = searchAfter - 1 - i;
assertEquals(expectedDocID, topDocs.scoreDocs[i].doc); assertEquals(expectedDocID, topDocs.scoreDocs[i].doc);
} }
@ -371,7 +376,6 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
} }
} }
writer.close();
reader.close(); reader.close();
dir.close(); dir.close();
} }
@ -389,12 +393,13 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
doc.add(new StringField("tf", "seg" + seg, Field.Store.YES)); doc.add(new StringField("tf", "seg" + seg, Field.Store.YES));
writer.addDocument(doc); writer.addDocument(doc);
if ((i > 0) && (i % 50 == 0)) { if ((i > 0) && (i % 50 == 0)) {
writer.commit(); writer.flush();
seg++; seg++;
} }
} }
final IndexReader reader = DirectoryReader.open(writer); final IndexReader reader = DirectoryReader.open(writer);
IndexSearcher searcher = new IndexSearcher(reader); writer.close();
final int numHits = 3; final int numHits = 3;
final int totalHitsThreshold = 3; final int totalHitsThreshold = 3;
final Sort sort = new Sort(FIELD_DOC); final Sort sort = new Sort(FIELD_DOC);
@ -402,6 +407,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
// sort by _doc should skip all non-competitive documents // sort by _doc should skip all non-competitive documents
{ {
final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, null, totalHitsThreshold); final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, null, totalHitsThreshold);
IndexSearcher searcher = newSearcher(reader);
searcher.search(new MatchAllDocsQuery(), collector); searcher.search(new MatchAllDocsQuery(), collector);
TopDocs topDocs = collector.topDocs(); TopDocs topDocs = collector.topDocs();
assertEquals(numHits, topDocs.scoreDocs.length); assertEquals(numHits, topDocs.scoreDocs.length);
@ -419,6 +425,7 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
BooleanQuery.Builder bq = new BooleanQuery.Builder(); BooleanQuery.Builder bq = new BooleanQuery.Builder();
bq.add(LongPoint.newRangeQuery("lf", lowerRange, Long.MAX_VALUE), BooleanClause.Occur.MUST); bq.add(LongPoint.newRangeQuery("lf", lowerRange, Long.MAX_VALUE), BooleanClause.Occur.MUST);
bq.add(new TermQuery(new Term("tf", "seg1")), BooleanClause.Occur.MUST); bq.add(new TermQuery(new Term("tf", "seg1")), BooleanClause.Occur.MUST);
IndexSearcher searcher = newSearcher(reader);
searcher.search(bq.build(), collector); searcher.search(bq.build(), collector);
TopDocs topDocs = collector.topDocs(); TopDocs topDocs = collector.topDocs();
@ -432,7 +439,48 @@ public class TestFieldSortOptimizationSkipping extends LuceneTestCase {
assertTrue(topDocs.totalHits.value < 10); // assert that very few docs were collected assertTrue(topDocs.totalHits.value < 10); // assert that very few docs were collected
} }
reader.close();
dir.close();
}
/**
* Test that sorting on _doc works correctly.
* This test goes through DefaultBulkSorter::scoreRange, where scorerIterator is BitSetIterator.
* As a conjunction of this BitSetIterator with DocComparator's iterator, we get BitSetConjunctionDISI.
* BitSetConjuctionDISI advances based on the DocComparator's iterator, and doesn't consider
* that its BitSetIterator may have advanced passed a certain doc.
*/
public void testDocSort() throws IOException {
final Directory dir = newDirectory();
final IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig());
final int numDocs = 4;
for (int i = 0; i < numDocs; ++i) {
final Document doc = new Document();
doc.add(new StringField("id", "id" + i, Field.Store.NO));
if (i < 2) {
doc.add(new LongPoint("lf", 1));
}
writer.addDocument(doc);
}
final IndexReader reader = DirectoryReader.open(writer);
writer.close(); writer.close();
IndexSearcher searcher = newSearcher(reader);
searcher.setQueryCache(null);
final int numHits = 10;
final int totalHitsThreshold = 10;
final Sort sort = new Sort(FIELD_DOC);
{
final TopFieldCollector collector = TopFieldCollector.create(sort, numHits, null, totalHitsThreshold);
BooleanQuery.Builder bq = new BooleanQuery.Builder();
bq.add(LongPoint.newExactQuery("lf", 1), BooleanClause.Occur.MUST);
bq.add(new TermQuery(new Term("id", "id3")), BooleanClause.Occur.MUST_NOT);
searcher.search(bq.build(), collector);
TopDocs topDocs = collector.topDocs();
assertEquals(2, topDocs.scoreDocs.length);
}
reader.close(); reader.close();
dir.close(); dir.close();
} }