diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 009241ffafe..1b7d56bc783 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -77,6 +77,9 @@ API Changes should do likewise for analysis components for tokenized text, or else changes to the encoding in future versions may be incompatible with older indexes. (Chongchen Chen, David Smiley) +* LUCENE-8956: QueryRescorer now only sorts the first topN hits instead of all + initial hits. (Paul Sanwald via Adrien Grand) + New Features * LUCENE-8936: Add SpanishMinimalStemFilter (vinod kumar via Tomoko Uchida) diff --git a/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java b/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java index b3452f43881..ce5b16a213a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java +++ b/lucene/core/src/java/org/apache/lucene/search/QueryRescorer.java @@ -23,6 +23,7 @@ import java.util.Comparator; import java.util.List; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.util.ArrayUtil; /** A {@link Rescorer} that uses a provided Query to assign * scores to the first-pass hits. 
@@ -50,6 +51,7 @@ public abstract class QueryRescorer extends Rescorer { @Override public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs, int topN) throws IOException { ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone(); + Arrays.sort(hits, new Comparator() { @Override @@ -109,32 +111,31 @@ public abstract class QueryRescorer extends Rescorer { hitUpto++; } - // TODO: we should do a partial sort (of only topN) - // instead, but typically the number of hits is - // smallish: - Arrays.sort(hits, - new Comparator() { - @Override - public int compare(ScoreDoc a, ScoreDoc b) { - // Sort by score descending, then docID ascending: - if (a.score > b.score) { - return -1; - } else if (a.score < b.score) { - return 1; - } else { - // This subtraction can't overflow int - // because docIDs are >= 0: - return a.doc - b.doc; - } - } - }); + Comparator sortDocComparator = new Comparator() { + @Override + public int compare(ScoreDoc a, ScoreDoc b) { + // Sort by score descending, then docID ascending: + if (a.score > b.score) { + return -1; + } else if (a.score < b.score) { + return 1; + } else { + // This subtraction can't overflow int + // because docIDs are >= 0: + return a.doc - b.doc; + } + } + }; if (topN < hits.length) { + ArrayUtil.select(hits, 0, hits.length, topN, sortDocComparator); ScoreDoc[] subset = new ScoreDoc[topN]; System.arraycopy(hits, 0, subset, 0, topN); hits = subset; } + Arrays.sort(hits, sortDocComparator); + return new TopDocs(firstPassTopDocs.totalHits, hits); } diff --git a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java index f6bab105911..9247aa636a6 100644 --- a/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/ArrayUtil.java @@ -479,12 +479,23 @@ public final class ArrayUtil { timSort(a, 0, a.length); } - /** Reorganize {@code arr[from:to[} so that the element at offset k is at the - * same position 
as if {@code arr[from:to[} was sorted, and all elements on - * its left are less than or equal to it, and all elements on its right are - * greater than or equal to it. - * This runs in linear time on average and in {@code n log(n)} time in the - * worst case.*/ + /** + * Reorganize {@code arr[from:to[} so that the element at offset k is at the + * same position as if {@code arr[from:to[} was sorted, and all elements on + * its left are less than or equal to it, and all elements on its right are + * greater than or equal to it. + * + * This runs in linear time on average and in {@code n log(n)} time in the + * worst case. + * + * @param arr Array to be re-organized. + * @param from Starting index (inclusive) for re-organization. Elements before this index + * will be left as is. + * @param to Ending index (exclusive). Elements at or after this index will be left as is. + * @param k Offset of the element to place at its sorted position. Value must be less than 'to' and greater than or equal to 'from'. + * @param comparator Comparator to use for sorting. + * + */ public static void select(T[] arr, int from, int to, int k, Comparator comparator) { new IntroSelector() { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java b/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java index f885f568c42..e603fad1fb1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestQueryRescorer.java @@ -20,7 +20,9 @@ package org.apache.lucene.search; import java.io.IOException; import java.util.Arrays; import java.util.Comparator; +import java.util.List; +import com.carrotsearch.randomizedtesting.generators.RandomPicks; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; import org.apache.lucene.document.NumericDocValuesField; @@ -30,11 +32,13 @@ import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.Term; 
import org.apache.lucene.search.BooleanClause.Occur; +import org.apache.lucene.search.similarities.BM25Similarity; import org.apache.lucene.search.similarities.ClassicSimilarity; import org.apache.lucene.search.spans.SpanNearQuery; import org.apache.lucene.search.spans.SpanQuery; import org.apache.lucene.search.spans.SpanTermQuery; import org.apache.lucene.store.Directory; +import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.LuceneTestCase; import org.apache.lucene.util.TestUtil; @@ -54,6 +58,94 @@ public class TestQueryRescorer extends LuceneTestCase { return LuceneTestCase.newIndexWriterConfig().setSimilarity(new ClassicSimilarity()); } + static List dictionary = Arrays.asList("river","quick","brown","fox","jumped","lazy","fence"); + + String randomSentence() { + final int length = random().nextInt(10); + StringBuilder sentence = new StringBuilder(dictionary.get(0)+" "); + for (int i = 0; i < length; i++) { + sentence.append(dictionary.get(random().nextInt(dictionary.size()-1))+" "); + } + return sentence.toString(); + } + + private IndexReader publishDocs(int numDocs, String fieldName, Directory dir) throws Exception { + + RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig()); + for (int i = 0; i < numDocs; i++) { + Document d = new Document(); + d.add(newStringField("id", Integer.toString(i), Field.Store.YES)); + d.add(newTextField(fieldName, randomSentence(), Field.Store.NO)); + w.addDocument(d); + } + IndexReader reader = w.getReader(); + w.close(); + return reader; + } + + public void testRescoreOfASubsetOfHits() throws Exception { + Directory dir = newDirectory(); + int numDocs = 100; + String fieldName = "field"; + IndexReader reader = publishDocs(numDocs, fieldName, dir); + + // Construct a query that will get numDocs hits. 
+ String wordOne = dictionary.get(0); + TermQuery termQuery = new TermQuery(new Term(fieldName, wordOne)); + IndexSearcher searcher = getSearcher(reader); + searcher.setSimilarity(new BM25Similarity()); + TopDocs hits = searcher.search(termQuery, numDocs); + + // Next, use a more specific phrase query that will return different scores + // from the above term query + String wordTwo = RandomPicks.randomFrom(random(), dictionary); + PhraseQuery phraseQuery = new PhraseQuery(1, fieldName, wordOne, wordTwo); + + // rescore, requesting a smaller topN + int topN = random().nextInt(numDocs-1); + TopDocs phraseQueryHits = QueryRescorer.rescore(searcher, hits, phraseQuery, 2.0, topN); + assertEquals(topN, phraseQueryHits.scoreDocs.length); + + for (int i = 1; i < phraseQueryHits.scoreDocs.length; i++) { + assertTrue(phraseQueryHits.scoreDocs[i].score <= phraseQueryHits.scoreDocs[i-1].score); + } + reader.close(); + dir.close(); + } + + public void testRescoreIsIdempotent() throws Exception { + Directory dir = newDirectory(); + int numDocs = 100; + String fieldName = "field"; + IndexReader reader = publishDocs(numDocs, fieldName, dir); + + // Construct a query that will get numDocs hits. 
+ String wordOne = dictionary.get(0); + TermQuery termQuery = new TermQuery(new Term(fieldName, wordOne)); + IndexSearcher searcher = getSearcher(reader); + searcher.setSimilarity(new BM25Similarity()); + TopDocs hits1 = searcher.search(termQuery, numDocs); + TopDocs hits2 = searcher.search(termQuery, numDocs); + + // Next, use a more specific phrase query that will return different scores + // from the above term query + String wordTwo = RandomPicks.randomFrom(random(), dictionary); + PhraseQuery phraseQuery = new PhraseQuery(1, fieldName, wordOne, wordTwo); + + // rescore, requesting the same hits as topN + int topN = numDocs; + TopDocs firstRescoreHits = QueryRescorer.rescore(searcher, hits1, phraseQuery, 2.0, topN); + + // now rescore again, where topN is less than numDocs + topN = random().nextInt(numDocs-1); + ScoreDoc[] secondRescoreHits = QueryRescorer.rescore(searcher, hits2, phraseQuery, 2.0, topN).scoreDocs; + ScoreDoc[] expectedTopNScoreDocs = ArrayUtil.copyOfSubArray(firstRescoreHits.scoreDocs, 0, topN); + CheckHits.checkEqual(phraseQuery, expectedTopNScoreDocs, secondRescoreHits); + + reader.close(); + dir.close(); + } + public void testBasic() throws Exception { Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig());