diff --git a/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java b/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java index d4f9ab72973..65c5c0f707c 100644 --- a/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java +++ b/core/src/main/java/org/apache/lucene/queries/MinDocQuery.java @@ -66,46 +66,54 @@ public final class MinDocQuery extends Query { return null; } final int segmentMinDoc = Math.max(0, minDoc - context.docBase); - final DocIdSetIterator disi = new DocIdSetIterator() { - - int doc = -1; - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert target > doc; - if (doc == -1) { - // skip directly to minDoc - doc = Math.max(target, segmentMinDoc); - } else { - doc = target; - } - if (doc >= maxDoc) { - doc = NO_MORE_DOCS; - } - return doc; - } - - @Override - public long cost() { - return maxDoc - segmentMinDoc; - } - - }; + final DocIdSetIterator disi = new MinDocIterator(segmentMinDoc, maxDoc); return new ConstantScoreScorer(this, score(), disi); } }; } + static class MinDocIterator extends DocIdSetIterator { + final int segmentMinDoc; + final int maxDoc; + int doc = -1; + + MinDocIterator(int segmentMinDoc, int maxDoc) { + this.segmentMinDoc = segmentMinDoc; + this.maxDoc = maxDoc; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + return advance(doc + 1); + } + + @Override + public int advance(int target) throws IOException { + assert target > doc; + if (doc == -1) { + // skip directly to minDoc + doc = Math.max(target, segmentMinDoc); + } else { + doc = target; + } + if (doc >= maxDoc) { + doc = NO_MORE_DOCS; + } + return doc; + } + + @Override + public long cost() { + return maxDoc - segmentMinDoc; + } + } + + @Override public String toString(String field) { return "MinDocQuery(minDoc=" + minDoc + ")"; diff --git a/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java b/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java new file mode 100644 index 00000000000..b9ed2290350 --- /dev/null +++ b/core/src/main/java/org/apache/lucene/queries/SearchAfterSortedDocQuery.java @@ -0,0 +1,165 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.lucene.queries; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.ConstantScoreScorer; +import org.apache.lucene.search.ConstantScoreWeight; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.EarlyTerminatingSortingCollector; +import org.apache.lucene.search.FieldComparator; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.LeafFieldComparator; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.Weight; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; + +/** + * A {@link Query} that only matches documents that are greater than the provided {@link FieldDoc}. + * This works only if the index is sorted according to the given search {@link Sort}. + */ +public class SearchAfterSortedDocQuery extends Query { + private final Sort sort; + private final FieldDoc after; + private final FieldComparator[] fieldComparators; + private final int[] reverseMuls; + + public SearchAfterSortedDocQuery(Sort sort, FieldDoc after) { + if (sort.getSort().length != after.fields.length) { + throw new IllegalArgumentException("after doc has " + after.fields.length + " value(s) but sort has " + + sort.getSort().length + "."); + } + this.sort = sort; + this.after = after; + int numFields = sort.getSort().length; + this.fieldComparators = new FieldComparator[numFields]; + this.reverseMuls = new int[numFields]; + for (int i = 0; i < numFields; i++) { + SortField sortField = sort.getSort()[i]; + FieldComparator fieldComparator = sortField.getComparator(1, i); + @SuppressWarnings("unchecked") + FieldComparator comparator = (FieldComparator) fieldComparator; + comparator.setTopValue(after.fields[i]); + fieldComparators[i] = fieldComparator; + reverseMuls[i] = sortField.getReverse() ? -1 : 1; + } + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException { + return new ConstantScoreWeight(this, 1.0f) { + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Sort segmentSort = context.reader().getMetaData().getSort(); + if (EarlyTerminatingSortingCollector.canEarlyTerminate(sort, segmentSort) == false) { + throw new IOException("search sort :[" + sort.getSort() + "] does not match the index sort:[" + segmentSort + "]"); + } + final int afterDoc = after.doc - context.docBase; + TopComparator comparator= getTopComparator(fieldComparators, reverseMuls, context, afterDoc); + final int maxDoc = context.reader().maxDoc(); + final int firstDoc = searchAfterDoc(comparator, 0, context.reader().maxDoc()); + if (firstDoc >= maxDoc) { + return null; + } + final DocIdSetIterator disi = new MinDocQuery.MinDocIterator(firstDoc, maxDoc); + return new ConstantScoreScorer(this, score(), disi); + } + }; + } + + @Override + public String toString(String field) { + return "SearchAfterSortedDocQuery(sort=" + sort + ", afterDoc=" + after.toString() + ")"; + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && + equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(SearchAfterSortedDocQuery other) { + return sort.equals(other.sort) && + after.doc == other.after.doc && + Double.compare(after.score, other.after.score) == 0 && + Arrays.equals(after.fields, other.after.fields); + } + + @Override + public int hashCode() { + return Objects.hash(classHash(), sort, after.doc, after.score, Arrays.hashCode(after.fields)); + } + + interface TopComparator { + boolean lessThanTop(int doc) throws IOException; + } + + static TopComparator getTopComparator(FieldComparator[] fieldComparators, + int[] reverseMuls, + LeafReaderContext leafReaderContext, + int topDoc) { + return doc -> { + // DVs use forward iterators so we recreate the iterator for each sort field + // every time we need to compare a document with the after doc. + // We could reuse the iterators when the comparison goes forward but + // this should only be called a few time per segment (binary search). + for (int i = 0; i < fieldComparators.length; i++) { + LeafFieldComparator comparator = fieldComparators[i].getLeafComparator(leafReaderContext); + int value = reverseMuls[i] * comparator.compareTop(doc); + if (value != 0) { + return value < 0; + } + } + + if (topDoc <= doc) { + return false; + } + return true; + }; + } + + /** + * Returns the first doc id greater than the provided after doc. + */ + static int searchAfterDoc(TopComparator comparator, int from, int to) throws IOException { + int low = from; + int high = to - 1; + + while (low <= high) { + int mid = (low + high) >>> 1; + if (comparator.lessThanTop(mid)) { + high = mid - 1; + } else { + low = mid + 1; + } + } + return low; + } + +} diff --git a/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 10c180f687e..82e572a180e 100644 --- a/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/core/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -21,11 +21,13 @@ package org.elasticsearch.search.query; import org.apache.lucene.index.IndexReader; import org.apache.lucene.queries.MinDocQuery; +import org.apache.lucene.queries.SearchAfterSortedDocQuery; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.ConstantScoreQuery; import org.apache.lucene.search.EarlyTerminatingSortingCollector; +import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.Query; @@ -50,7 +52,6 @@ import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.search.suggest.SuggestPhase; import java.util.LinkedList; -import java.util.List; import static org.elasticsearch.search.query.QueryCollectorContext.createCancellableCollectorContext; import static org.elasticsearch.search.query.QueryCollectorContext.createEarlySortingTerminationCollectorContext; @@ -130,16 +131,17 @@ public class QueryPhase implements SearchPhase { final ScrollContext scrollContext = searchContext.scrollContext(); if (scrollContext != null) { - if (returnsDocsInOrder(query, searchContext.sort())) { - if (scrollContext.totalHits == -1) { - // first round - assert scrollContext.lastEmittedDoc == null; - // there is not much that we can optimize here since we want to collect all - // documents in order to get the total number of hits - } else { + if (scrollContext.totalHits == -1) { + // first round + assert scrollContext.lastEmittedDoc == null; + // there is not much that we can optimize here since we want to collect all + // documents in order to get the total number of hits + + } else { + final ScoreDoc after = scrollContext.lastEmittedDoc; + if (returnsDocsInOrder(query, searchContext.sort())) { // now this gets interesting: since we sort in index-order, we can directly // skip to the desired doc - final ScoreDoc after = scrollContext.lastEmittedDoc; if (after != null) { BooleanQuery bq = new BooleanQuery.Builder() .add(query, BooleanClause.Occur.MUST) @@ -150,6 +152,17 @@ public class QueryPhase implements SearchPhase { // ... and stop collecting after ${size} matches searchContext.terminateAfter(searchContext.size()); searchContext.trackTotalHits(false); + } else if (canEarlyTerminate(indexSort, searchContext)) { + // now this gets interesting: since the index sort matches the search sort, we can directly + // skip to the desired doc + if (after != null) { + BooleanQuery bq = new BooleanQuery.Builder() + .add(query, BooleanClause.Occur.MUST) + .add(new SearchAfterSortedDocQuery(indexSort, (FieldDoc) after), BooleanClause.Occur.FILTER) + .build(); + query = bq; + } + searchContext.trackTotalHits(false); } } } @@ -189,7 +202,10 @@ public class QueryPhase implements SearchPhase { final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, reader, collectors.stream().anyMatch(QueryCollectorContext::shouldCollect)); final boolean shouldCollect = topDocsFactory.shouldCollect(); - if (scrollContext == null && topDocsFactory.numHits() > 0 && canEarlyTerminate(indexSort, searchContext)) { + + if (topDocsFactory.numHits() > 0 && + (scrollContext == null || scrollContext.totalHits != -1) && + canEarlyTerminate(indexSort, searchContext)) { // top docs collection can be early terminated based on index sort // add the collector context first so we don't early terminate aggs but only top docs collectors.addFirst(createEarlySortingTerminationCollectorContext(reader, searchContext.query(), indexSort, diff --git a/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java b/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java new file mode 100644 index 00000000000..25c5ff6fa21 --- /dev/null +++ b/core/src/test/java/org/apache/lucene/queries/SearchAfterSortedDocQueryTests.java @@ -0,0 +1,130 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.lucene.queries; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.SortedDocValuesField; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.index.ReaderUtil; +import org.apache.lucene.search.FieldDoc; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.QueryUtils; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.SortedNumericSortField; +import org.apache.lucene.search.SortedSetSortField; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.elasticsearch.test.ESTestCase; + +import java.io.IOException; + +import static org.hamcrest.Matchers.equalTo; + +public class SearchAfterSortedDocQueryTests extends ESTestCase { + + public void testBasics() { + Sort sort1 = new Sort( + new SortedNumericSortField("field1", SortField.Type.INT), + new SortedSetSortField("field2", false) + ); + Sort sort2 = new Sort( + new SortedNumericSortField("field1", SortField.Type.INT), + new SortedSetSortField("field3", false) + ); + FieldDoc fieldDoc1 = new FieldDoc(0, 0f, new Object[]{5, new BytesRef("foo")}); + FieldDoc fieldDoc2 = new FieldDoc(0, 0f, new Object[]{5, new BytesRef("foo")}); + + SearchAfterSortedDocQuery query1 = new SearchAfterSortedDocQuery(sort1, fieldDoc1); + SearchAfterSortedDocQuery query2 = new SearchAfterSortedDocQuery(sort1, fieldDoc2); + SearchAfterSortedDocQuery query3 = new SearchAfterSortedDocQuery(sort2, fieldDoc2); + QueryUtils.check(query1); + QueryUtils.checkEqual(query1, query2); + QueryUtils.checkUnequal(query1, query3); + } + + public void testInvalidSort() { + Sort sort = new Sort(new SortedNumericSortField("field1", SortField.Type.INT)); + FieldDoc fieldDoc = new FieldDoc(0, 0f, new Object[] {4, 5}); + IllegalArgumentException ex = + expectThrows(IllegalArgumentException.class, () -> new SearchAfterSortedDocQuery(sort, fieldDoc)); + assertThat(ex.getMessage(), equalTo("after doc has 2 value(s) but sort has 1.")); + } + + public void testRandom() throws IOException { + final int numDocs = randomIntBetween(100, 200); + final Document doc = new Document(); + final Directory dir = newDirectory(); + Sort sort = new Sort( + new SortedNumericSortField("number1", SortField.Type.INT, randomBoolean()), + new SortField("string", SortField.Type.STRING, randomBoolean()) + ); + final IndexWriterConfig config = new IndexWriterConfig(); + config.setIndexSort(sort); + final RandomIndexWriter w = new RandomIndexWriter(random(), dir, config); + for (int i = 0; i < numDocs; ++i) { + int rand = randomIntBetween(0, 10); + doc.add(new SortedNumericDocValuesField("number", rand)); + doc.add(new SortedDocValuesField("string", new BytesRef(randomAlphaOfLength(randomIntBetween(5, 50))))); + w.addDocument(doc); + doc.clear(); + if (rarely()) { + w.commit(); + } + } + final IndexReader reader = w.getReader(); + final IndexSearcher searcher = newSearcher(reader); + + int step = randomIntBetween(1, 10); + FixedBitSet bitSet = new FixedBitSet(numDocs); + TopDocs topDocs = null; + for (int i = 0; i < numDocs;) { + if (topDocs != null) { + FieldDoc after = (FieldDoc) topDocs.scoreDocs[topDocs.scoreDocs.length - 1]; + topDocs = searcher.search(new SearchAfterSortedDocQuery(sort, after), step, sort); + } else { + topDocs = searcher.search(new MatchAllDocsQuery(), step, sort); + } + i += step; + for (ScoreDoc topDoc : topDocs.scoreDocs) { + int readerIndex = ReaderUtil.subIndex(topDoc.doc, reader.leaves()); + final LeafReaderContext leafReaderContext = reader.leaves().get(readerIndex); + int docRebase = topDoc.doc - leafReaderContext.docBase; + if (leafReaderContext.reader().hasDeletions()) { + assertTrue(leafReaderContext.reader().getLiveDocs().get(docRebase)); + } + assertFalse(bitSet.get(topDoc.doc)); + bitSet.set(topDoc.doc); + } + } + assertThat(bitSet.cardinality(), equalTo(reader.numDocs())); + w.close(); + reader.close(); + dir.close(); + } +} diff --git a/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java b/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java index 2633cb706e0..b05c6dff04b 100644 --- a/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java +++ b/core/src/test/java/org/elasticsearch/search/query/QueryPhaseTests.java @@ -36,12 +36,12 @@ import org.apache.lucene.search.BooleanClause.Occur; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; import org.apache.lucene.search.ConstantScoreQuery; +import org.apache.lucene.search.FieldComparator; import org.apache.lucene.search.FieldDoc; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; -import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.TermQuery; @@ -50,10 +50,8 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.elasticsearch.action.search.SearchTask; import org.elasticsearch.index.query.ParsedQuery; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.internal.ScrollContext; -import org.elasticsearch.search.internal.SearchContext; import org.elasticsearch.search.sort.SortAndFormats; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.TestSearchContext; @@ -64,11 +62,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThan; -import static org.hamcrest.Matchers.nullValue; public class QueryPhaseTests extends ESTestCase { @@ -440,4 +436,71 @@ public class QueryPhaseTests extends ESTestCase { reader.close(); dir.close(); } + + public void testIndexSortScrollOptimization() throws Exception { + Directory dir = newDirectory(); + final Sort sort = new Sort( + new SortField("rank", SortField.Type.INT), + new SortField("tiebreaker", SortField.Type.INT) + ); + IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort); + RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); + final int numDocs = scaledRandomIntBetween(100, 200); + for (int i = 0; i < numDocs; ++i) { + Document doc = new Document(); + doc.add(new NumericDocValuesField("rank", random().nextInt())); + doc.add(new NumericDocValuesField("tiebreaker", i)); + w.addDocument(doc); + } + w.close(); + + TestSearchContext context = new TestSearchContext(null); + context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery())); + ScrollContext scrollContext = new ScrollContext(); + scrollContext.lastEmittedDoc = null; + scrollContext.maxScore = Float.NaN; + scrollContext.totalHits = -1; + context.scrollContext(scrollContext); + context.setTask(new SearchTask(123L, "", "", "", null)); + context.setSize(10); + context.sort(new SortAndFormats(sort, new DocValueFormat[] {DocValueFormat.RAW, DocValueFormat.RAW})); + + final AtomicBoolean collected = new AtomicBoolean(); + final IndexReader reader = DirectoryReader.open(dir); + IndexSearcher contextSearcher = new IndexSearcher(reader) { + protected void search(List leaves, Weight weight, Collector collector) throws IOException { + collected.set(true); + super.search(leaves, weight, collector); + } + }; + + QueryPhase.execute(context, contextSearcher, sort); + assertThat(context.queryResult().topDocs().totalHits, equalTo(numDocs)); + assertTrue(collected.get()); + assertNull(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits(), equalTo(numDocs)); + int sizeMinus1 = context.queryResult().topDocs().scoreDocs.length - 1; + FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().scoreDocs[sizeMinus1]; + + QueryPhase.execute(context, contextSearcher, sort); + assertThat(context.queryResult().topDocs().totalHits, equalTo(numDocs)); + assertTrue(collected.get()); + assertTrue(context.queryResult().terminatedEarly()); + assertThat(context.terminateAfter(), equalTo(0)); + assertThat(context.queryResult().getTotalHits(), equalTo(numDocs)); + FieldDoc firstDoc = (FieldDoc) context.queryResult().topDocs().scoreDocs[0]; + for (int i = 0; i < sort.getSort().length; i++) { + @SuppressWarnings("unchecked") + FieldComparator comparator = (FieldComparator) sort.getSort()[i].getComparator(1, i); + int cmp = comparator.compareValues(firstDoc.fields[i], lastDoc.fields[i]); + if (cmp == 0) { + continue; + } + assertThat(cmp, equalTo(1)); + break; + } + reader.close(); + dir.close(); + } }