From bade484998e2d933cdc02248ba50a9eb2e54aaa8 Mon Sep 17 00:00:00 2001 From: Luca Cavanna Date: Thu, 3 Feb 2022 17:19:05 +0100 Subject: [PATCH] LUCENE-10385: Implement Weight#count on IndexSortSortedNumericDocValuesRangeQuery (#635) IndexSortSortedNumericDocValuesRangeQuery can implement its count method and compute the count through a binary search, the same binary search that is used to execute the query itself, whenever all the required conditions are met. --- lucene/CHANGES.txt | 3 + ...xSortSortedNumericDocValuesRangeQuery.java | 46 +++-- ...xSortSortedNumericDocValuesRangeQuery.java | 166 ++++++++++++------ 3 files changed, 146 insertions(+), 69 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 727c0fa5e9d..d162714036f 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -128,6 +128,9 @@ New Features based on TotalHitCountCollector that allows users to parallelize counting the number of hits. (Luca Cavanna, Adrien Grand) +* LUCENE-10385: Implement Weight#count on IndexSortSortedNumericDocValuesRangeQuery + to speed up computing the number of hits when possible. 
(Luca Cavanna, Adrien Grand) + Improvements --------------------- diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/IndexSortSortedNumericDocValuesRangeQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/IndexSortSortedNumericDocValuesRangeQuery.java index 829ed712469..a35352564a9 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/IndexSortSortedNumericDocValuesRangeQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/IndexSortSortedNumericDocValuesRangeQuery.java @@ -156,20 +156,9 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query { return new ConstantScoreWeight(this, boost) { @Override public Scorer scorer(LeafReaderContext context) throws IOException { - SortedNumericDocValues sortedNumericValues = - DocValues.getSortedNumeric(context.reader(), field); - NumericDocValues numericValues = DocValues.unwrapSingleton(sortedNumericValues); - - if (numericValues != null) { - Sort indexSort = context.reader().getMetaData().getSort(); - if (indexSort != null - && indexSort.getSort().length > 0 - && indexSort.getSort()[0].getField().equals(field)) { - - SortField sortField = indexSort.getSort()[0]; - DocIdSetIterator disi = getDocIdSetIterator(sortField, context, numericValues); - return new ConstantScoreScorer(this, score(), scoreMode, disi); - } + DocIdSetIterator disi = getDocIdSetIteratorOrNull(context); + if (disi != null) { + return new ConstantScoreScorer(this, score(), scoreMode, disi); } return fallbackWeight.scorer(context); } @@ -180,9 +169,36 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query { // if the fallback query is cacheable. 
return fallbackWeight.isCacheable(ctx); } + + @Override + public int count(LeafReaderContext context) throws IOException { + BoundedDocSetIdIterator disi = getDocIdSetIteratorOrNull(context); + if (disi != null) { + return disi.lastDoc - disi.firstDoc; + } + return fallbackWeight.count(context); + } }; } + private BoundedDocSetIdIterator getDocIdSetIteratorOrNull(LeafReaderContext context) + throws IOException { + SortedNumericDocValues sortedNumericValues = + DocValues.getSortedNumeric(context.reader(), field); + NumericDocValues numericValues = DocValues.unwrapSingleton(sortedNumericValues); + if (numericValues != null) { + Sort indexSort = context.reader().getMetaData().getSort(); + if (indexSort != null + && indexSort.getSort().length > 0 + && indexSort.getSort()[0].getField().equals(field)) { + + SortField sortField = indexSort.getSort()[0]; + return getDocIdSetIterator(sortField, context, numericValues); + } + } + return null; + } + /** * Computes the document IDs that lie within the range [lowerValue, upperValue] by performing * binary search on the field's doc values. @@ -195,7 +211,7 @@ public class IndexSortSortedNumericDocValuesRangeQuery extends Query { * {@link DocIdSetIterator} makes sure to wrap the original docvalues to skip over documents with * no value. */ - private DocIdSetIterator getDocIdSetIterator( + private BoundedDocSetIdIterator getDocIdSetIterator( SortField sortField, LeafReaderContext context, DocIdSetIterator delegate) throws IOException { long lower = sortField.getReverse() ? 
upperValue : lowerValue; diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java index fb8ecf26007..f876dddd0bc 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/search/TestIndexSortSortedNumericDocValuesRangeQuery.java @@ -38,6 +38,7 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; import org.apache.lucene.search.SortedNumericSortField; import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHitCountCollectorManager; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.analysis.MockAnalyzer; @@ -95,7 +96,7 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas } } - private void assertSameHits(IndexSearcher searcher, Query q1, Query q2, boolean scores) + private static void assertSameHits(IndexSearcher searcher, Query q1, Query q2, boolean scores) throws IOException { final int maxDoc = searcher.getIndexReader().maxDoc(); final TopDocs td1 = searcher.search(q1, maxDoc, scores ? Sort.RELEVANCE : Sort.INDEXORDER); @@ -167,43 +168,50 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas IndexSearcher searcher = newSearcher(reader); // Test ranges consisting of one value. 
- assertEquals(1, searcher.count(createQuery("field", -80, -80))); - assertEquals(1, searcher.count(createQuery("field", -5, -5))); - assertEquals(2, searcher.count(createQuery("field", 0, 0))); - assertEquals(1, searcher.count(createQuery("field", 30, 30))); - assertEquals(1, searcher.count(createQuery("field", 35, 35))); + assertNumberOfHits(searcher, createQuery("field", -80, -80), 1); + assertNumberOfHits(searcher, createQuery("field", -5, -5), 1); + assertNumberOfHits(searcher, createQuery("field", 0, 0), 2); + assertNumberOfHits(searcher, createQuery("field", 30, 30), 1); + assertNumberOfHits(searcher, createQuery("field", 35, 35), 1); - assertEquals(0, searcher.count(createQuery("field", -90, -90))); - assertEquals(0, searcher.count(createQuery("field", 5, 5))); - assertEquals(0, searcher.count(createQuery("field", 40, 40))); + assertNumberOfHits(searcher, createQuery("field", -90, -90), 0); + assertNumberOfHits(searcher, createQuery("field", 5, 5), 0); + assertNumberOfHits(searcher, createQuery("field", 40, 40), 0); // Test the lower end of the document value range. - assertEquals(2, searcher.count(createQuery("field", -90, -4))); - assertEquals(2, searcher.count(createQuery("field", -80, -4))); - assertEquals(1, searcher.count(createQuery("field", -70, -4))); - assertEquals(2, searcher.count(createQuery("field", -80, -5))); + assertNumberOfHits(searcher, createQuery("field", -90, -4), 2); + assertNumberOfHits(searcher, createQuery("field", -80, -4), 2); + assertNumberOfHits(searcher, createQuery("field", -70, -4), 1); + assertNumberOfHits(searcher, createQuery("field", -80, -5), 2); // Test the upper end of the document value range. 
- assertEquals(1, searcher.count(createQuery("field", 25, 34))); - assertEquals(2, searcher.count(createQuery("field", 25, 35))); - assertEquals(2, searcher.count(createQuery("field", 25, 36))); - assertEquals(2, searcher.count(createQuery("field", 30, 35))); + assertNumberOfHits(searcher, createQuery("field", 25, 34), 1); + assertNumberOfHits(searcher, createQuery("field", 25, 35), 2); + assertNumberOfHits(searcher, createQuery("field", 25, 36), 2); + assertNumberOfHits(searcher, createQuery("field", 30, 35), 2); // Test multiple occurrences of the same value. - assertEquals(2, searcher.count(createQuery("field", -4, 4))); - assertEquals(2, searcher.count(createQuery("field", -4, 0))); - assertEquals(2, searcher.count(createQuery("field", 0, 4))); - assertEquals(3, searcher.count(createQuery("field", 0, 30))); + assertNumberOfHits(searcher, createQuery("field", -4, 4), 2); + assertNumberOfHits(searcher, createQuery("field", -4, 0), 2); + assertNumberOfHits(searcher, createQuery("field", 0, 4), 2); + assertNumberOfHits(searcher, createQuery("field", 0, 30), 3); // Test ranges that span all documents. 
- assertEquals(6, searcher.count(createQuery("field", -80, 35))); - assertEquals(6, searcher.count(createQuery("field", -90, 40))); + assertNumberOfHits(searcher, createQuery("field", -80, 35), 6); + assertNumberOfHits(searcher, createQuery("field", -90, 40), 6); writer.close(); reader.close(); dir.close(); } + private static void assertNumberOfHits(IndexSearcher searcher, Query query, int numberOfHits) + throws IOException { + assertEquals( + numberOfHits, searcher.search(query, new TotalHitCountCollectorManager()).intValue()); + assertEquals(numberOfHits, searcher.count(query)); + } + public void testIndexSortDocValuesWithOddLength() throws Exception { testIndexSortDocValuesWithOddLength(false); testIndexSortDocValuesWithOddLength(true); @@ -229,38 +237,38 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas IndexSearcher searcher = newSearcher(reader); // Test ranges consisting of one value. - assertEquals(1, searcher.count(createQuery("field", -80, -80))); - assertEquals(1, searcher.count(createQuery("field", -5, -5))); - assertEquals(2, searcher.count(createQuery("field", 0, 0))); - assertEquals(1, searcher.count(createQuery("field", 5, 5))); - assertEquals(1, searcher.count(createQuery("field", 30, 30))); - assertEquals(1, searcher.count(createQuery("field", 35, 35))); + assertNumberOfHits(searcher, createQuery("field", -80, -80), 1); + assertNumberOfHits(searcher, createQuery("field", -5, -5), 1); + assertNumberOfHits(searcher, createQuery("field", 0, 0), 2); + assertNumberOfHits(searcher, createQuery("field", 5, 5), 1); + assertNumberOfHits(searcher, createQuery("field", 30, 30), 1); + assertNumberOfHits(searcher, createQuery("field", 35, 35), 1); - assertEquals(0, searcher.count(createQuery("field", -90, -90))); - assertEquals(0, searcher.count(createQuery("field", 6, 6))); - assertEquals(0, searcher.count(createQuery("field", 40, 40))); + assertNumberOfHits(searcher, createQuery("field", -90, -90), 0); + 
assertNumberOfHits(searcher, createQuery("field", 6, 6), 0); + assertNumberOfHits(searcher, createQuery("field", 40, 40), 0); // Test the lower end of the document value range. - assertEquals(2, searcher.count(createQuery("field", -90, -4))); - assertEquals(2, searcher.count(createQuery("field", -80, -4))); - assertEquals(1, searcher.count(createQuery("field", -70, -4))); - assertEquals(2, searcher.count(createQuery("field", -80, -5))); + assertNumberOfHits(searcher, createQuery("field", -90, -4), 2); + assertNumberOfHits(searcher, createQuery("field", -80, -4), 2); + assertNumberOfHits(searcher, createQuery("field", -70, -4), 1); + assertNumberOfHits(searcher, createQuery("field", -80, -5), 2); // Test the upper end of the document value range. - assertEquals(1, searcher.count(createQuery("field", 25, 34))); - assertEquals(2, searcher.count(createQuery("field", 25, 35))); - assertEquals(2, searcher.count(createQuery("field", 25, 36))); - assertEquals(2, searcher.count(createQuery("field", 30, 35))); + assertNumberOfHits(searcher, createQuery("field", 25, 34), 1); + assertNumberOfHits(searcher, createQuery("field", 25, 35), 2); + assertNumberOfHits(searcher, createQuery("field", 25, 36), 2); + assertNumberOfHits(searcher, createQuery("field", 30, 35), 2); // Test multiple occurrences of the same value. - assertEquals(2, searcher.count(createQuery("field", -4, 4))); - assertEquals(2, searcher.count(createQuery("field", -4, 0))); - assertEquals(2, searcher.count(createQuery("field", 0, 4))); - assertEquals(4, searcher.count(createQuery("field", 0, 30))); + assertNumberOfHits(searcher, createQuery("field", -4, 4), 2); + assertNumberOfHits(searcher, createQuery("field", -4, 0), 2); + assertNumberOfHits(searcher, createQuery("field", 0, 4), 2); + assertNumberOfHits(searcher, createQuery("field", 0, 30), 4); // Test ranges that span all documents. 
- assertEquals(7, searcher.count(createQuery("field", -80, 35))); - assertEquals(7, searcher.count(createQuery("field", -90, 40))); + assertNumberOfHits(searcher, createQuery("field", -80, 35), 7); + assertNumberOfHits(searcher, createQuery("field", -90, 40), 7); writer.close(); reader.close(); @@ -285,10 +293,10 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas DirectoryReader reader = writer.getReader(); IndexSearcher searcher = newSearcher(reader); - assertEquals(1, searcher.count(createQuery("field", 42, 43))); - assertEquals(1, searcher.count(createQuery("field", 42, 42))); - assertEquals(0, searcher.count(createQuery("field", 41, 41))); - assertEquals(0, searcher.count(createQuery("field", 43, 43))); + assertNumberOfHits(searcher, createQuery("field", 42, 43), 1); + assertNumberOfHits(searcher, createQuery("field", 42, 42), 1); + assertNumberOfHits(searcher, createQuery("field", 41, 41), 0); + assertNumberOfHits(searcher, createQuery("field", 43, 43), 0); writer.close(); reader.close(); @@ -316,11 +324,11 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas DirectoryReader reader = writer.getReader(); IndexSearcher searcher = newSearcher(reader); - assertEquals(2, searcher.count(createQuery("field", -70, 0))); - assertEquals(2, searcher.count(createQuery("field", -2, 35))); + assertNumberOfHits(searcher, createQuery("field", -70, 0), 2); + assertNumberOfHits(searcher, createQuery("field", -2, 35), 2); - assertEquals(4, searcher.count(createQuery("field", -80, 35))); - assertEquals(4, searcher.count(createQuery("field", Long.MIN_VALUE, Long.MAX_VALUE))); + assertNumberOfHits(searcher, createQuery("field", -80, 35), 4); + assertNumberOfHits(searcher, createQuery("field", Long.MIN_VALUE, Long.MAX_VALUE), 4); writer.close(); reader.close(); @@ -450,6 +458,56 @@ public class TestIndexSortSortedNumericDocValuesRangeQuery extends LuceneTestCas reader.close(); } + public void testCount() throws 
IOException { + Directory dir = newDirectory(); + IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random())); + Sort indexSort = new Sort(new SortedNumericSortField("field", SortField.Type.LONG)); + iwc.setIndexSort(indexSort); + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + doc.add(new SortedNumericDocValuesField("field", 10)); + writer.addDocument(doc); + IndexReader reader = writer.getReader(); + IndexSearcher searcher = newSearcher(reader); + + Query fallbackQuery = LongPoint.newRangeQuery("field", 1, 42); + Query query = new IndexSortSortedNumericDocValuesRangeQuery("field", 1, 42, fallbackQuery); + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + for (LeafReaderContext context : searcher.getLeafContexts()) { + assertEquals(1, weight.count(context)); + } + + writer.close(); + reader.close(); + dir.close(); + } + + public void testFallbackCount() throws IOException { + Directory dir = newDirectory(); + IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random())); + Sort indexSort = new Sort(new SortedNumericSortField("field", SortField.Type.LONG)); + iwc.setIndexSort(indexSort); + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + doc.add(new SortedNumericDocValuesField("field", 10)); + writer.addDocument(doc); + IndexReader reader = writer.getReader(); + IndexSearcher searcher = newSearcher(reader); + + // we use an unrealistic query that exposes its own Weight#count + Query fallbackQuery = new MatchNoDocsQuery(); + // the index is not sorted on this field, the fallback query is used + Query query = new IndexSortSortedNumericDocValuesRangeQuery("another", 1, 42, fallbackQuery); + Weight weight = query.createWeight(searcher, ScoreMode.COMPLETE, 1.0f); + for (LeafReaderContext context : searcher.getLeafContexts()) { + assertEquals(0, weight.count(context)); + } + + writer.close(); + reader.close(); 
+ dir.close(); + } + private Document createDocument(String field, long value) { Document doc = new Document(); doc.add(new SortedNumericDocValuesField(field, value));