LUCENE-8412: TopFieldCollector no longer takes a trackDocScores option.

This commit is contained in:
Adrien Grand 2018-07-23 09:05:02 +02:00
parent 34686c00dd
commit 55bfadbce1
32 changed files with 346 additions and 264 deletions

View File

@ -60,9 +60,13 @@ API Changes
no longer have an option to compute the maximum score when sorting by field.
(Adrien Grand)
* LUCENE-8411: TopFieldCollector no longer takes a fillFields options, it now
* LUCENE-8411: TopFieldCollector no longer takes a fillFields option, it now
always fills fields. (Adrien Grand)
* LUCENE-8412: TopFieldCollector no longer takes a trackDocScores option. Scores
need to be set on top hits via TopFieldCollector#populateScores instead.
(Adrien Grand)
Changes in Runtime Behavior
* LUCENE-8333: Switch MoreLikeThis.setMaxDocFreqPct to use maxDoc instead of

View File

@ -82,3 +82,10 @@ all matches.
Because filling sort values doesn't have a significant overhead, the fillFields
option has been removed from TopFieldCollector factory methods. Everything
behaves as if it was set to true.
## TopFieldCollector no longer takes a trackDocScores option ##
Computing scores at collection time is less efficient than running a second
request in order to only compute scores for documents that made it to the top
hits. As a consequence, the trackDocScores option has been removed and can be
replaced with the new TopFieldCollector#populateScores helper method.

View File

@ -112,7 +112,6 @@ public abstract class ReadTask extends PerfTask {
// Weight public again, we can go back to
// pulling the Weight ourselves:
TopFieldCollector collector = TopFieldCollector.create(sort, numHits,
withScore(),
withTotalHits());
searcher.search(q, collector);
hits = collector.topDocs();
@ -208,12 +207,6 @@ public abstract class ReadTask extends PerfTask {
*/
public abstract boolean withTraverse();
/** Whether scores should be computed (only useful with
* field sort) */
public boolean withScore() {
return true;
}
/** Whether totalHits should be computed (only useful with
* field sort) */
public boolean withTotalHits() {

View File

@ -29,7 +29,6 @@ import org.apache.lucene.search.SortField;
*/
public class SearchWithSortTask extends ReadTask {
private boolean doScore = true;
private Sort sort;
public SearchWithSortTask(PerfRunData runData) {
@ -60,9 +59,6 @@ public class SearchWithSortTask extends ReadTask {
sortField0 = SortField.FIELD_DOC;
} else if (field.equals("score")) {
sortField0 = SortField.FIELD_SCORE;
} else if (field.equals("noscore")) {
doScore = false;
continue;
} else {
int index = field.lastIndexOf(":");
String fieldName;
@ -115,11 +111,6 @@ public class SearchWithSortTask extends ReadTask {
public boolean withWarm() {
return false;
}
@Override
public boolean withScore() {
return doScore;
}
@Override
public Sort getSort() {

View File

@ -513,7 +513,7 @@ public class IndexSearcher {
@Override
public TopFieldCollector newCollector() throws IOException {
// TODO: don't pay the price for accurate hit counts by default
return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, doDocScores, true);
return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, true);
}
@Override
@ -528,7 +528,11 @@ public class IndexSearcher {
};
return search(query, manager);
TopFieldDocs topDocs = search(query, manager);
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
return topDocs;
}
/**

View File

@ -44,17 +44,12 @@ public class SortRescorer extends Rescorer {
// Copy ScoreDoc[] and sort by ascending docID:
ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();
Arrays.sort(hits,
new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
return a.doc - b.doc;
}
});
Comparator<ScoreDoc> docIdComparator = Comparator.comparingInt(sd -> sd.doc);
Arrays.sort(hits, docIdComparator);
List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
TopFieldCollector collector = TopFieldCollector.create(sort, topN, true, true);
TopFieldCollector collector = TopFieldCollector.create(sort, topN, true);
// Now merge sort docIDs from hits, with reader's leaves:
int hitUpto = 0;
@ -90,7 +85,15 @@ public class SortRescorer extends Rescorer {
hitUpto++;
}
return collector.topDocs();
TopDocs rescoredDocs = collector.topDocs();
// set scores from the original score docs
assert hits.length == rescoredDocs.scoreDocs.length;
ScoreDoc[] rescoredDocsClone = rescoredDocs.scoreDocs.clone();
Arrays.sort(rescoredDocsClone, docIdComparator);
for (int i = 0; i < rescoredDocsClone.length; ++i) {
rescoredDocsClone[i].score = hits[i].score;
}
return rescoredDocs;
}
@Override

View File

@ -19,16 +19,20 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.FieldValueHitQueue.Entry;
import org.apache.lucene.util.FutureObjects;
import org.apache.lucene.util.PriorityQueue;
/**
* A {@link Collector} that sorts by {@link SortField} using
* {@link FieldComparator}s.
* <p>
* See the {@link #create(org.apache.lucene.search.Sort, int, boolean, boolean)} method
* See the {@link #create(org.apache.lucene.search.Sort, int, boolean)} method
* for instantiating a TopFieldCollector.
*
* @lucene.experimental
@ -44,10 +48,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final LeafFieldComparator comparator;
final int reverseMul;
final boolean mayNeedScoresTwice;
Scorer scorer;
MultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul, boolean mayNeedScoresTwice) {
MultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul) {
if (comparators.length == 1) {
this.reverseMul = reverseMul[0];
this.comparator = comparators[0];
@ -55,14 +58,10 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
this.reverseMul = 1;
this.comparator = new MultiLeafFieldComparator(comparators, reverseMul);
}
this.mayNeedScoresTwice = mayNeedScoresTwice;
}
@Override
public void setScorer(Scorer scorer) throws IOException {
if (mayNeedScoresTwice && scorer instanceof ScoreCachingWrappingScorer == false) {
scorer = new ScoreCachingWrappingScorer(scorer);
}
comparator.setScorer(scorer);
this.scorer = scorer;
}
@ -93,20 +92,12 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final Sort sort;
final FieldValueHitQueue<Entry> queue;
final boolean trackDocScores;
final boolean mayNeedScoresTwice;
final boolean trackTotalHits;
public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits,
boolean trackDocScores, boolean trackTotalHits) {
super(queue, numHits, sort.needsScores() || trackDocScores);
public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits, boolean trackTotalHits) {
super(queue, numHits, sort.needsScores());
this.sort = sort;
this.queue = queue;
this.trackDocScores = trackDocScores;
// If one of the sort fields needs scores, and if we also track scores, then
// we might call scorer.score() several times per doc so wrapping the scorer
// to cache scores would help
this.mayNeedScoresTwice = sort.needsScores() && trackDocScores;
this.trackTotalHits = trackTotalHits;
}
@ -122,7 +113,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
canEarlyTerminate(sort, indexSort);
final int initialTotalHits = totalHits;
return new MultiComparatorLeafCollector(comparators, reverseMul, mayNeedScoresTwice) {
return new MultiComparatorLeafCollector(comparators, reverseMul) {
@Override
public void collect(int doc) throws IOException {
@ -146,10 +137,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
}
}
if (trackDocScores) {
score = scorer.score();
}
// This hit is competitive - replace bottom element in queue & adjustTop
comparator.copy(bottom.slot, doc);
updateBottom(doc, score);
@ -158,10 +145,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// Startup transient: queue hasn't gathered numHits yet
final int slot = totalHits - 1;
if (trackDocScores) {
score = scorer.score();
}
// Copy hit into queue
comparator.copy(slot, doc);
add(slot, doc, score);
@ -184,19 +167,15 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final Sort sort;
int collectedHits;
final FieldValueHitQueue<Entry> queue;
final boolean trackDocScores;
final FieldDoc after;
final boolean mayNeedScoresTwice;
final boolean trackTotalHits;
public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits,
boolean trackDocScores, boolean trackTotalHits) {
super(queue, numHits, trackDocScores || sort.needsScores());
boolean trackTotalHits) {
super(queue, numHits, sort.needsScores());
this.sort = sort;
this.queue = queue;
this.trackDocScores = trackDocScores;
this.after = after;
this.mayNeedScoresTwice = sort.needsScores() && trackDocScores;
this.trackTotalHits = trackTotalHits;
FieldComparator<?>[] comparators = queue.comparators;
@ -217,7 +196,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
indexSort != null &&
canEarlyTerminate(sort, indexSort);
final int initialTotalHits = totalHits;
return new MultiComparatorLeafCollector(queue.getComparators(context), queue.getReverseMul(), mayNeedScoresTwice) {
return new MultiComparatorLeafCollector(queue.getComparators(context), queue.getReverseMul()) {
@Override
public void collect(int doc) throws IOException {
@ -256,10 +235,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// This hit is competitive - replace bottom element in queue & adjustTop
comparator.copy(bottom.slot, doc);
// Compute score only if it is competitive.
if (trackDocScores) {
score = scorer.score();
}
updateBottom(doc, score);
comparator.setBottom(bottom.slot);
@ -272,10 +247,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// Copy hit into queue
comparator.copy(slot, doc);
// Compute score only if it is competitive.
if (trackDocScores) {
score = scorer.score();
}
bottom = pq.add(new Entry(slot, docBase + doc, score));
queueFull = collectedHits == numHits;
if (queueFull) {
@ -325,13 +296,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the sort criteria (SortFields).
* @param numHits
* the number of results to collect.
* @param trackDocScores
* specifies whether document scores should be tracked and set on the
* results. Note that if set to false, then the results' scores will
* be set to Float.NaN. Setting this to true affects performance, as
* it incurs the score computation on each competitive result.
* Therefore if document scores are not required by the application,
* it is recommended to set it to false.
* @param trackTotalHits
* specifies whether the total number of hits should be tracked. If
* set to false, the value of {@link TopFieldDocs#totalHits} will be
@ -339,9 +303,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* @return a {@link TopFieldCollector} instance which will sort the results by
* the sort criteria.
*/
public static TopFieldCollector create(Sort sort, int numHits,
boolean trackDocScores, boolean trackTotalHits) {
return create(sort, numHits, null, trackDocScores, trackTotalHits);
public static TopFieldCollector create(Sort sort, int numHits, boolean trackTotalHits) {
return create(sort, numHits, null, trackTotalHits);
}
/**
@ -358,14 +321,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the number of results to collect.
* @param after
* only hits after this FieldDoc will be collected
* @param trackDocScores
* specifies whether document scores should be tracked and set on the
* results. Note that if set to false, then the results' scores will
* be set to Float.NaN. Setting this to true affects performance, as
* it incurs the score computation on each competitive result.
* Therefore if document scores are not required by the application,
* it is recommended to set it to false.
* <code>trackDocScores</code> to true as well.
* @param trackTotalHits
* specifies whether the total number of hits should be tracked. If
* set to false, the value of {@link TopFieldDocs#totalHits} will be
@ -374,7 +329,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the sort criteria.
*/
public static TopFieldCollector create(Sort sort, int numHits, FieldDoc after,
boolean trackDocScores, boolean trackTotalHits) {
boolean trackTotalHits) {
if (sort.fields.length == 0) {
throw new IllegalArgumentException("Sort must contain at least one field");
@ -387,7 +342,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits);
if (after == null) {
return new SimpleFieldCollector(sort, queue, numHits, trackDocScores, trackTotalHits);
return new SimpleFieldCollector(sort, queue, numHits, trackTotalHits);
} else {
if (after.fields == null) {
throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
@ -397,7 +352,46 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length);
}
return new PagingFieldCollector(sort, queue, after, numHits, trackDocScores, trackTotalHits);
return new PagingFieldCollector(sort, queue, after, numHits, trackTotalHits);
}
}
/**
* Populate {@link ScoreDoc#score scores} of the given {@code topDocs}.
* @param topDocs the top docs to populate
* @param searcher the index searcher that has been used to compute {@code topDocs}
* @param query the query that has been used to compute {@code topDocs}
* @throws IllegalArgumentException if there is evidence that {@code topDocs}
* have been computed against a different searcher or a different query.
* @lucene.experimental
*/
public static void populateScores(ScoreDoc[] topDocs, IndexSearcher searcher, Query query) throws IOException {
// Get the score docs sorted in doc id order
topDocs = topDocs.clone();
Arrays.sort(topDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
final Weight weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE, 1);
List<LeafReaderContext> contexts = searcher.getIndexReader().leaves();
LeafReaderContext currentContext = null;
Scorer currentScorer = null;
for (ScoreDoc scoreDoc : topDocs) {
if (currentContext == null || scoreDoc.doc >= currentContext.docBase + currentContext.reader().maxDoc()) {
FutureObjects.checkIndex(scoreDoc.doc, searcher.getIndexReader().maxDoc());
int newContextIndex = ReaderUtil.subIndex(scoreDoc.doc, contexts);
currentContext = contexts.get(newContextIndex);
final ScorerSupplier scorerSupplier = weight.scorerSupplier(currentContext);
if (scorerSupplier == null) {
throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
}
currentScorer = scorerSupplier.get(1); // random-access
}
final int leafDoc = scoreDoc.doc - currentContext.docBase;
assert leafDoc >= 0;
final int advanced = currentScorer.iterator().advance(leafDoc);
if (leafDoc != advanced) {
throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
}
scoreDoc.score = currentScorer.score();
}
}

View File

@ -2324,11 +2324,11 @@ public class TestIndexSorting extends LuceneTestCase {
System.out.println("TEST: iter=" + iter + " numHits=" + numHits);
}
TopFieldCollector c1 = TopFieldCollector.create(sort, numHits, true, true);
TopFieldCollector c1 = TopFieldCollector.create(sort, numHits, true);
s1.search(new MatchAllDocsQuery(), c1);
TopDocs hits1 = c1.topDocs();
TopFieldCollector c2 = TopFieldCollector.create(sort, numHits, true, false);
TopFieldCollector c2 = TopFieldCollector.create(sort, numHits, false);
s2.search(new MatchAllDocsQuery(), c2);
TopDocs hits2 = c2.topDocs();

View File

@ -386,10 +386,10 @@ public class TestBoolean2 extends LuceneTestCase {
}
// check diff (randomized) scorers (from AssertingSearcher) produce the same results
TopFieldCollector collector = TopFieldCollector.create(sort, 1000, true, false);
TopFieldCollector collector = TopFieldCollector.create(sort, 1000, false);
searcher.search(q1, collector);
ScoreDoc[] hits1 = collector.topDocs().scoreDocs;
collector = TopFieldCollector.create(sort, 1000, true, false);
collector = TopFieldCollector.create(sort, 1000, false);
searcher.search(q1, collector);
ScoreDoc[] hits2 = collector.topDocs().scoreDocs;
tot+=hits2.length;
@ -402,10 +402,10 @@ public class TestBoolean2 extends LuceneTestCase {
assertEquals(mulFactor*collector.totalHits + NUM_EXTRA_DOCS/2, hits4.totalHits);
// test diff (randomized) scorers produce the same results on bigSearcher as well
collector = TopFieldCollector.create(sort, 1000 * mulFactor, true, false);
collector = TopFieldCollector.create(sort, 1000 * mulFactor, false);
bigSearcher.search(q1, collector);
hits1 = collector.topDocs().scoreDocs;
collector = TopFieldCollector.create(sort, 1000 * mulFactor, true, false);
collector = TopFieldCollector.create(sort, 1000 * mulFactor, false);
bigSearcher.search(q1, collector);
hits2 = collector.topDocs().scoreDocs;
CheckHits.checkEqual(q1, hits1, hits2);

View File

@ -86,7 +86,7 @@ public class TestElevationComparator extends LuceneTestCase {
new SortField(null, SortField.Type.SCORE, reversed)
);
TopDocsCollector<Entry> topCollector = TopFieldCollector.create(sort, 50, true, true);
TopDocsCollector<Entry> topCollector = TopFieldCollector.create(sort, 50, true);
searcher.search(newq.build(), topCollector);
TopDocs topDocs = topCollector.topDocs(0, 10);

View File

@ -146,7 +146,7 @@ public class TestSortRandom extends LuceneTestCase {
}
final int hitCount = TestUtil.nextInt(random, 1, r.maxDoc() + 20);
final RandomQuery f = new RandomQuery(random.nextLong(), random.nextFloat(), docValues);
hits = s.search(f, hitCount, sort, random.nextBoolean());
hits = s.search(f, hitCount, sort, false);
if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter + " " + hits.totalHits + " hits; topN=" + hitCount + "; reverse=" + reverse + "; sortMissingLast=" + sortMissingLast + " sort=" + sort);

View File

@ -281,7 +281,7 @@ public class TestTopDocsMerge extends LuceneTestCase {
topHits = searcher.search(query, numHits);
}
} else {
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true, true);
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true);
searcher.search(query, c);
if (useFrom) {
from = TestUtil.nextInt(random(), 0, numHits - 1);
@ -330,7 +330,7 @@ public class TestTopDocsMerge extends LuceneTestCase {
if (sort == null) {
subHits = subSearcher.search(w, numHits);
} else {
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true, true);
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true);
subSearcher.search(w, c);
subHits = c.topDocs(0, numHits);
}

View File

@ -18,11 +18,14 @@ package org.apache.lucene.search;
import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter;
@ -70,8 +73,7 @@ public class TestTopFieldCollector extends LuceneTestCase {
Sort[] sort = new Sort[] { new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery();
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10,
false, true);
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
is.search(q, tdc);
@ -83,14 +85,13 @@ public class TestTopFieldCollector extends LuceneTestCase {
}
}
public void testSortWithoutScoreTracking() throws Exception {
public void testSort() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery();
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, false,
true);
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
is.search(q, tdc);
@ -110,10 +111,10 @@ public class TestTopFieldCollector extends LuceneTestCase {
// the index is not sorted
TopDocsCollector<Entry> tdc;
if (i % 2 == 0) {
tdc = TopFieldCollector.create(sort, 10, false, false);
tdc = TopFieldCollector.create(sort, 10, false);
} else {
FieldDoc fieldDoc = new FieldDoc(1, Float.NaN, new Object[] { 1 });
tdc = TopFieldCollector.create(sort, 10, fieldDoc, false, false);
tdc = TopFieldCollector.create(sort, 10, fieldDoc, false);
}
is.search(q, tdc);
@ -125,32 +126,13 @@ public class TestTopFieldCollector extends LuceneTestCase {
}
}
}
public void testSortWithScoreTracking() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery();
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true,
true);
is.search(q, tdc);
TopDocs td = tdc.topDocs();
ScoreDoc[] sd = td.scoreDocs;
for(int j = 0; j < sd.length; j++) {
assertTrue(!Float.isNaN(sd[j].score));
}
}
}
public void testSortWithScoreTrackingNoResults() throws Exception {
public void testSortNoResults() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) {
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true, true);
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
TopDocs td = tdc.topDocs();
assertEquals(0, td.totalHits);
}
@ -182,59 +164,112 @@ public class TestTopFieldCollector extends LuceneTestCase {
.build();
final IndexSearcher searcher = new IndexSearcher(reader);
for (Sort sort : new Sort[] {new Sort(SortField.FIELD_SCORE), new Sort(new SortField("f", SortField.Type.SCORE))}) {
for (boolean doDocScores : new boolean[] {false, true}) {
final TopFieldCollector topCollector = TopFieldCollector.create(sort, TestUtil.nextInt(random(), 1, 2), doDocScores, true);
final Collector assertingCollector = new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
final LeafCollector in = topCollector.getLeafCollector(context);
return new FilterLeafCollector(in) {
@Override
public void setScorer(final Scorer scorer) throws IOException {
Scorer s = new Scorer(null) {
final TopFieldCollector topCollector = TopFieldCollector.create(sort, TestUtil.nextInt(random(), 1, 2), true);
final Collector assertingCollector = new Collector() {
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
final LeafCollector in = topCollector.getLeafCollector(context);
return new FilterLeafCollector(in) {
@Override
public void setScorer(final Scorer scorer) throws IOException {
Scorer s = new Scorer(null) {
int lastComputedDoc = -1;
int lastComputedDoc = -1;
@Override
public float score() throws IOException {
if (lastComputedDoc == docID()) {
throw new AssertionError("Score computed twice on " + docID());
}
lastComputedDoc = docID();
return scorer.score();
@Override
public float score() throws IOException {
if (lastComputedDoc == docID()) {
throw new AssertionError("Score computed twice on " + docID());
}
lastComputedDoc = docID();
return scorer.score();
}
@Override
public float getMaxScore(int upTo) throws IOException {
return scorer.getMaxScore(upTo);
}
@Override
public float getMaxScore(int upTo) throws IOException {
return scorer.getMaxScore(upTo);
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public int docID() {
return scorer.docID();
}
@Override
public DocIdSetIterator iterator() {
return scorer.iterator();
}
@Override
public DocIdSetIterator iterator() {
return scorer.iterator();
}
};
super.setScorer(s);
}
};
}
@Override
public ScoreMode scoreMode() {
return topCollector.scoreMode();
}
};
searcher.search(query, assertingCollector);
}
};
super.setScorer(s);
}
};
}
@Override
public ScoreMode scoreMode() {
return topCollector.scoreMode();
}
};
searcher.search(query, assertingCollector);
}
reader.close();
w.close();
dir.close();
}
public void testPopulateScores() throws IOException {
Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
Document doc = new Document();
TextField field = new TextField("f", "foo bar", Store.NO);
doc.add(field);
NumericDocValuesField sortField = new NumericDocValuesField("sort", 0);
doc.add(sortField);
w.addDocument(doc);
field.setStringValue("");
sortField.setLongValue(3);
w.addDocument(doc);
field.setStringValue("foo foo bar");
sortField.setLongValue(2);
w.addDocument(doc);
w.flush();
field.setStringValue("foo");
sortField.setLongValue(2);
w.addDocument(doc);
field.setStringValue("bar bar bar");
sortField.setLongValue(0);
w.addDocument(doc);
IndexReader reader = w.getReader();
w.close();
IndexSearcher searcher = newSearcher(reader);
for (String queryText : new String[] { "foo", "bar" }) {
Query query = new TermQuery(new Term("f", queryText));
for (boolean reverse : new boolean[] {false, true}) {
ScoreDoc[] sortedByDoc = searcher.search(query, 10).scoreDocs;
Arrays.sort(sortedByDoc, Comparator.comparingInt(sd -> sd.doc));
Sort sort = new Sort(new SortField("sort", SortField.Type.LONG, reverse));
ScoreDoc[] sortedByField = searcher.search(query, 10, sort).scoreDocs;
ScoreDoc[] sortedByFieldClone = sortedByField.clone();
TopFieldCollector.populateScores(sortedByFieldClone, searcher, query);
for (int i = 0; i < sortedByFieldClone.length; ++i) {
assertEquals(sortedByFieldClone[i].doc, sortedByField[i].doc);
assertSame(((FieldDoc) sortedByFieldClone[i]).fields, ((FieldDoc) sortedByField[i]).fields);
assertEquals(sortedByFieldClone[i].score,
sortedByDoc[Arrays.binarySearch(sortedByDoc, sortedByFieldClone[i], Comparator.comparingInt(sd -> sd.doc))].score, 0f);
}
}
}
reader.close();
dir.close();
}
}

View File

@ -136,9 +136,8 @@ public class TestTopFieldCollectorEarlyTermination extends LuceneTestCase {
} else {
after = null;
}
final boolean trackDocScores = random().nextBoolean();
final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, after, trackDocScores, true);
final TopFieldCollector collector2 = TopFieldCollector.create(sort, numHits, after, trackDocScores, false);
final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, after, true);
final TopFieldCollector collector2 = TopFieldCollector.create(sort, numHits, after, false);
final Query query;
if (random().nextBoolean()) {

View File

@ -241,7 +241,7 @@ public class DrillSideways {
@Override
public TopFieldCollector newCollector() throws IOException {
return TopFieldCollector.create(sort, fTopN, after, doDocScores, true);
return TopFieldCollector.create(sort, fTopN, after, true);
}
@Override
@ -255,14 +255,22 @@ public class DrillSideways {
};
ConcurrentDrillSidewaysResult<TopFieldDocs> r = search(query, collectorManager);
return new DrillSidewaysResult(r.facets, r.collectorResult);
TopFieldDocs topDocs = r.collectorResult;
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
return new DrillSidewaysResult(r.facets, topDocs);
} else {
final TopFieldCollector hitCollector =
TopFieldCollector.create(sort, fTopN, after, doDocScores, true);
TopFieldCollector.create(sort, fTopN, after, true);
DrillSidewaysResult r = search(query, hitCollector);
return new DrillSidewaysResult(r.facets, hitCollector.topDocs());
TopFieldDocs topDocs = hitCollector.topDocs();
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
return new DrillSidewaysResult(r.facets, topDocs);
}
} else {
return search(after, query, topN);

View File

@ -230,7 +230,6 @@ public class FacetsCollector extends SimpleCollector implements Collector {
}
hitsCollector = TopFieldCollector.create(sort, n,
(FieldDoc) after,
doDocScores,
true); // TODO: can we disable exact hit counts
} else {
hitsCollector = TopScoreDocCollector.create(n, after, true);
@ -238,6 +237,9 @@ public class FacetsCollector extends SimpleCollector implements Collector {
searcher.search(q, MultiCollector.wrap(hitsCollector, fc));
topDocs = hitsCollector.topDocs();
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, q);
}
}
return topDocs;
}

View File

@ -304,7 +304,7 @@ public class BlockGroupingCollector extends SimpleCollector {
collector = TopScoreDocCollector.create(maxDocsPerGroup);
} else {
// Sort by fields
collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, needsScores, true); // TODO: disable exact counts?
collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, true); // TODO: disable exact counts?
}
float groupMaxScore = needsScores ? Float.NEGATIVE_INFINITY : Float.NaN;

View File

@ -50,7 +50,6 @@ public class GroupingSearch {
private int groupDocsOffset;
private int groupDocsLimit = 1;
private boolean includeScores = true;
private boolean includeMaxScore = true;
private Double maxCacheRAMMB;
@ -153,8 +152,7 @@ public class GroupingSearch {
int topNInsideGroup = groupDocsOffset + groupDocsLimit;
TopGroupsCollector secondPassCollector
= new TopGroupsCollector(grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup,
includeScores, includeMaxScore);
= new TopGroupsCollector(grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup, includeMaxScore);
if (cachedCollector != null && cachedCollector.isCached()) {
cachedCollector.replay(secondPassCollector);
@ -173,7 +171,7 @@ public class GroupingSearch {
int topN = groupOffset + groupLimit;
final Query endDocsQuery = searcher.rewrite(this.groupEndDocs);
final Weight groupEndDocs = searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1);
BlockGroupingCollector c = new BlockGroupingCollector(groupSort, topN, includeScores, groupEndDocs);
BlockGroupingCollector c = new BlockGroupingCollector(groupSort, topN, groupSort.needsScores() || sortWithinGroup.needsScores(), groupEndDocs);
searcher.search(query, c);
int topNInsideGroup = groupDocsOffset + groupDocsLimit;
return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup);
@ -268,17 +266,6 @@ public class GroupingSearch {
return this;
}
/**
* Whether to include the scores per doc inside a group.
*
* @param includeScores Whether to include the scores per doc inside a group
* @return <code>this</code>
*/
public GroupingSearch setIncludeScores(boolean includeScores) {
this.includeScores = includeScores;
return this;
}
/**
* Whether to include the score of the most relevant document per group.
*

View File

@ -54,13 +54,12 @@ public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> {
* @param groupSort the order in which groups are returned
* @param withinGroupSort the order in which documents are sorted in each group
* @param maxDocsPerGroup the maximum number of docs to collect for each group
* @param getScores if true, record the scores of all docs in each group
* @param getMaxScores if true, record the maximum score for each group
*/
public TopGroupsCollector(GroupSelector<T> groupSelector, Collection<SearchGroup<T>> groups, Sort groupSort, Sort withinGroupSort,
int maxDocsPerGroup, boolean getScores, boolean getMaxScores) {
int maxDocsPerGroup, boolean getMaxScores) {
super(groupSelector, groups,
new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getScores, getMaxScores));
new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores));
this.groupSort = Objects.requireNonNull(groupSort);
this.withinGroupSort = Objects.requireNonNull(withinGroupSort);
this.maxDocsPerGroup = maxDocsPerGroup;
@ -114,13 +113,13 @@ public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> {
private final boolean needsScores;
TopDocsReducer(Sort withinGroupSort,
int maxDocsPerGroup, boolean getScores, boolean getMaxScores) {
this.needsScores = getScores || getMaxScores || withinGroupSort.needsScores();
int maxDocsPerGroup, boolean getMaxScores) {
this.needsScores = getMaxScores || withinGroupSort.needsScores();
if (withinGroupSort == Sort.RELEVANCE) {
supplier = () -> new TopDocsAndMaxScoreCollector(true, TopScoreDocCollector.create(maxDocsPerGroup), null);
} else {
supplier = () -> {
TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, getScores, true); // TODO: disable exact counts?
TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, true); // TODO: disable exact counts?
MaxScoreCollector maxScoreCollector = getMaxScores ? new MaxScoreCollector() : null;
return new TopDocsAndMaxScoreCollector(false, topDocsCollector, maxScoreCollector);
};

View File

@ -145,7 +145,7 @@ public class TestGrouping extends LuceneTestCase {
final FirstPassGroupingCollector<?> c1 = createRandomFirstPassCollector(groupField, groupSort, 10);
indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
final TopGroupsCollector<?> c2 = createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true, true);
final TopGroupsCollector<?> c2 = createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true);
indexSearcher.search(new TermQuery(new Term("content", "random")), c2);
final TopGroups<?> groups = c2.getTopGroups(0);
@ -218,11 +218,10 @@ public class TestGrouping extends LuceneTestCase {
Sort sortWithinGroup,
int groupOffset,
int maxDocsPerGroup,
boolean getScores,
boolean getMaxScores) throws IOException {
Collection<SearchGroup<T>> searchGroups = firstPassGroupingCollector.getTopGroups(groupOffset);
return new TopGroupsCollector<>(firstPassGroupingCollector.getGroupSelector(), searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getScores, getMaxScores);
return new TopGroupsCollector<>(firstPassGroupingCollector.getGroupSelector(), searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
}
// Basically converts searchGroups from MutableValue to BytesRef if grouping by ValueSource
@ -233,11 +232,10 @@ public class TestGrouping extends LuceneTestCase {
Sort groupSort,
Sort sortWithinGroup,
int maxDocsPerGroup,
boolean getScores,
boolean getMaxScores) throws IOException {
if (firstPassGroupingCollector.getGroupSelector().getClass().isAssignableFrom(TermGroupSelector.class)) {
GroupSelector<BytesRef> selector = (GroupSelector<BytesRef>) firstPassGroupingCollector.getGroupSelector();
return new TopGroupsCollector<>(selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup , getScores, getMaxScores);
return new TopGroupsCollector<>(selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
} else {
ValueSource vs = new BytesRefFieldSource(groupField);
List<SearchGroup<MutableValue>> mvalSearchGroups = new ArrayList<>(searchGroups.size());
@ -254,7 +252,7 @@ public class TestGrouping extends LuceneTestCase {
mvalSearchGroups.add(sg);
}
ValueSourceGroupSelector selector = new ValueSourceGroupSelector(vs, new HashMap<>());
return new TopGroupsCollector<>(selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getScores, getMaxScores);
return new TopGroupsCollector<>(selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
}
}
@ -437,7 +435,6 @@ public class TestGrouping extends LuceneTestCase {
private TopGroups<BytesRef> slowGrouping(GroupDoc[] groupDocs,
String searchTerm,
boolean getScores,
boolean getMaxScores,
boolean doAllGroups,
Sort groupSort,
@ -507,7 +504,7 @@ public class TestGrouping extends LuceneTestCase {
for(int docIDX=docOffset; docIDX < docIDXLimit; docIDX++) {
final GroupDoc d = docs.get(docIDX);
final FieldDoc fd;
fd = new FieldDoc(d.id, getScores ? d.score : Float.NaN, fillFields(d, docSort));
fd = new FieldDoc(d.id, Float.NaN, fillFields(d, docSort));
hits[docIDX-docOffset] = fd;
}
} else {
@ -829,14 +826,11 @@ public class TestGrouping extends LuceneTestCase {
}
final String searchTerm = "real" + random().nextInt(3);
boolean getScores = random().nextBoolean();
final boolean getMaxScores = random().nextBoolean();
final Sort groupSort = getRandomSort();
//final Sort groupSort = new Sort(new SortField[] {new SortField("sort1", SortField.STRING), new SortField("id", SortField.INT)});
final Sort docSort = getRandomSort();
getScores |= (groupSort.needsScores() || docSort.needsScores());
final int topNGroups = TestUtil.nextInt(random(), 1, 30);
//final int topNGroups = 10;
final int docsPerGroup = TestUtil.nextInt(random(), 1, 50);
@ -850,7 +844,7 @@ public class TestGrouping extends LuceneTestCase {
final boolean doCache = random().nextBoolean();
final boolean doAllGroups = random().nextBoolean();
if (VERBOSE) {
System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " dF=" + r.docFreq(new Term("content", searchTerm)) +" dFBlock=" + rBlocks.docFreq(new Term("content", searchTerm)) + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getScores=" + getScores + " getMaxScores=" + getMaxScores);
System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " dF=" + r.docFreq(new Term("content", searchTerm)) +" dFBlock=" + rBlocks.docFreq(new Term("content", searchTerm)) + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getMaxScores=" + getMaxScores);
}
String groupField = "group";
@ -935,7 +929,7 @@ public class TestGrouping extends LuceneTestCase {
// Get 1st pass top groups using shards
final TopGroups<BytesRef> topGroupsShards = searchShards(s, shards.subSearchers, query, groupSort, docSort,
groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores, true, true);
groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, true, true);
final TopGroupsCollector<?> c2;
if (topGroups != null) {
@ -946,7 +940,7 @@ public class TestGrouping extends LuceneTestCase {
}
}
c2 = createSecondPassCollector(c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getScores, getMaxScores);
c2 = createSecondPassCollector(c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getMaxScores);
if (doCache) {
if (cCache.isCached()) {
if (VERBOSE) {
@ -977,7 +971,7 @@ public class TestGrouping extends LuceneTestCase {
}
}
final TopGroups<BytesRef> expectedGroups = slowGrouping(groupDocs, searchTerm, getScores, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset);
final TopGroups<BytesRef> expectedGroups = slowGrouping(groupDocs, searchTerm, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset);
if (VERBOSE) {
if (expectedGroups == null) {
@ -1023,17 +1017,16 @@ public class TestGrouping extends LuceneTestCase {
}
}
assertEquals(docIDToID, expectedGroups, groupsResult, true, true, getScores, true);
assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true);
// Confirm merged shards match:
assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, getScores, true);
assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, true);
if (topGroupsShards != null) {
verifyShards(shards.docStarts, topGroupsShards);
}
final boolean needsScores = getScores || getMaxScores || docSort == null;
final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups, needsScores,
sBlocks.createWeight(sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups,
groupSort.needsScores() || docSort.needsScores(), sBlocks.createWeight(sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
final AllGroupsCollector<BytesRef> allGroupsCollector2;
final Collector c4;
if (doAllGroups) {
@ -1079,7 +1072,7 @@ public class TestGrouping extends LuceneTestCase {
// Get shard'd block grouping result:
final TopGroups<BytesRef> topGroupsBlockShards = searchShards(sBlocks, shardsBlocks.subSearchers, query,
groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores, false, false);
groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, false, false);
if (expectedGroups != null) {
// Fixup scores for reader2
@ -1122,8 +1115,8 @@ public class TestGrouping extends LuceneTestCase {
}
}
assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, getScores, false);
assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, getScores, false);
assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, false);
assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, false);
}
r.close();
@ -1146,7 +1139,7 @@ public class TestGrouping extends LuceneTestCase {
}
private TopGroups<BytesRef> searchShards(IndexSearcher topSearcher, ShardSearcher[] subSearchers, Query query, Sort groupSort, Sort docSort, int groupOffset, int topNGroups, int docOffset,
int topNDocs, boolean getScores, boolean getMaxScores, boolean canUseIDV, boolean preFlex) throws Exception {
int topNDocs, boolean getMaxScores, boolean canUseIDV, boolean preFlex) throws Exception {
// TODO: swap in caching, all groups collector hereassertEquals(expected.totalHitCount, actual.totalHitCount);
// too...
@ -1154,7 +1147,7 @@ public class TestGrouping extends LuceneTestCase {
System.out.println("TEST: " + subSearchers.length + " shards: " + Arrays.toString(subSearchers) + " canUseIDV=" + canUseIDV);
}
// Run 1st pass collector to get top groups per shard
final Weight w = topSearcher.createWeight(topSearcher.rewrite(query), getScores || getMaxScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1);
final Weight w = topSearcher.createWeight(topSearcher.rewrite(query), groupSort.needsScores() || docSort.needsScores() || getMaxScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1);
final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>();
List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>();
FirstPassGroupingCollector<?> firstPassCollector = null;
@ -1208,7 +1201,7 @@ public class TestGrouping extends LuceneTestCase {
final TopGroups<BytesRef>[] shardTopGroups = new TopGroups[subSearchers.length];
for(int shardIDX=0;shardIDX<subSearchers.length;shardIDX++) {
final TopGroupsCollector<?> secondPassCollector = createSecondPassCollector(firstPassGroupingCollectors.get(shardIDX),
groupField, mergedTopGroups, groupSort, docSort, docOffset + topNDocs, getScores, getMaxScores);
groupField, mergedTopGroups, groupSort, docSort, docOffset + topNDocs, getMaxScores);
subSearchers[shardIDX].search(w, secondPassCollector);
shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0);
if (VERBOSE) {
@ -1232,7 +1225,7 @@ public class TestGrouping extends LuceneTestCase {
}
}
private void assertEquals(int[] docIDtoID, TopGroups<BytesRef> expected, TopGroups<BytesRef> actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean testScores, boolean idvBasedImplsUsed) {
private void assertEquals(int[] docIDtoID, TopGroups<BytesRef> expected, TopGroups<BytesRef> actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean idvBasedImplsUsed) {
if (expected == null) {
assertNull(actual);
return;
@ -1279,12 +1272,6 @@ public class TestGrouping extends LuceneTestCase {
final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX];
//System.out.println(" actual doc=" + docIDtoID[actualFD.doc] + " score=" + actualFD.score);
assertEquals(expectedFD.doc, docIDtoID[actualFD.doc]);
if (testScores) {
assertEquals(expectedFD.score, actualFD.score, 0.1);
} else {
// TODO: too anal for now
//assertEquals(Float.NaN, actualFD.score);
}
assertArrayEquals(expectedFD.fields, actualFD.fields);
}
}

View File

@ -647,7 +647,7 @@ public class AnalyzingInfixSuggester extends Lookup implements Closeable {
//System.out.println("finalQuery=" + finalQuery);
// Sort by weight, descending:
TopFieldCollector c = TopFieldCollector.create(SORT, num, false, false);
TopFieldCollector c = TopFieldCollector.create(SORT, num, false);
List<LookupResult> results = null;
SearcherManager mgr;
IndexSearcher searcher;

View File

@ -535,7 +535,7 @@ public class ExpandComponent extends SearchComponent implements PluginInfoInitia
DocIdSetIterator iterator = new BitSetIterator(groupBits, 0); // cost is not useful here
int group;
while ((group = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, false, true);
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, true);
groups.put(group, collector);
}
@ -619,7 +619,7 @@ public class ExpandComponent extends SearchComponent implements PluginInfoInitia
Iterator<LongCursor> iterator = groupSet.iterator();
while (iterator.hasNext()) {
LongCursor cursor = iterator.next();
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, false, true);
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, true);
groups.put(cursor.value, collector);
}

View File

@ -1327,6 +1327,7 @@ public class QueryComponent extends SearchComponent
secondPhaseBuilder.addCommandField(
new TopGroupsFieldCommand.Builder()
.setQuery(cmd.getQuery())
.setField(schemaField)
.setGroupSort(groupingSpec.getGroupSort())
.setSortWithinGroup(groupingSpec.getSortWithinGroup())

View File

@ -590,6 +590,14 @@ public class Grouping {
return null;
}
protected void populateScoresIfNecessary() throws IOException {
if (needScores) {
for (GroupDocs<?> groups : result.groups) {
TopFieldCollector.populateScores(groups.scoreDocs, searcher, query);
}
}
}
protected NamedList commonResponse() {
NamedList groupResult = new SimpleOrderedMap();
grouped.add(key, groupResult); // grouped={ key={
@ -747,7 +755,7 @@ public class Grouping {
groupedDocsToCollect = Math.max(groupedDocsToCollect, 1);
Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE;
secondPass = new TopGroupsCollector<>(new TermGroupSelector(groupBy),
topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores, needScores
topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores
);
if (totalCount == TotalCount.grouped) {
@ -766,7 +774,10 @@ public class Grouping {
@Override
protected void finish() throws IOException {
result = secondPass != null ? secondPass.getTopGroups(0) : null;
if (secondPass != null) {
result = secondPass.getTopGroups(0);
populateScoresIfNecessary();
}
if (main) {
mainResult = createSimpleResponse();
return;
@ -850,7 +861,7 @@ public class Grouping {
if (withinGroupSort == null || withinGroupSort.equals(Sort.RELEVANCE)) {
subCollector = topCollector = TopScoreDocCollector.create(groupDocsToCollect);
} else {
topCollector = TopFieldCollector.create(searcher.weightSort(withinGroupSort), groupDocsToCollect, needScores, true);
topCollector = TopFieldCollector.create(searcher.weightSort(withinGroupSort), groupDocsToCollect, true);
if (needScores) {
maxScoreCollector = new MaxScoreCollector();
subCollector = MultiCollector.wrap(topCollector, maxScoreCollector);
@ -952,7 +963,7 @@ public class Grouping {
groupdDocsToCollect = Math.max(groupdDocsToCollect, 1);
Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE;
secondPass = new TopGroupsCollector<>(newSelector(),
topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores, needScores
topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores
);
if (totalCount == TotalCount.grouped) {
@ -971,7 +982,10 @@ public class Grouping {
@Override
protected void finish() throws IOException {
result = secondPass != null ? secondPass.getTopGroups(0) : null;
if (secondPass != null) {
result = secondPass.getTopGroups(0);
populateScoresIfNecessary();
}
if (main) {
mainResult = createSimpleResponse();
return;

View File

@ -27,6 +27,7 @@ import com.carrotsearch.hppc.IntIntHashMap;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
@ -49,6 +50,8 @@ public class ReRankCollector extends TopDocsCollector {
final private int length;
final private Set<BytesRef> boostedPriority; // order is the "priority"
final private Rescorer reRankQueryRescorer;
final private Sort sort;
final private Query query;
public ReRankCollector(int reRankDocs,
@ -61,13 +64,15 @@ public class ReRankCollector extends TopDocsCollector {
this.reRankDocs = reRankDocs;
this.length = length;
this.boostedPriority = boostedPriority;
this.query = cmd.getQuery();
Sort sort = cmd.getSort();
if(sort == null) {
this.sort = null;
this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length));
} else {
sort = sort.rewrite(searcher);
this.sort = sort = sort.rewrite(searcher);
//scores are needed for Rescorer (regardless of whether sort needs it)
this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), true, true);
this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), true);
}
this.searcher = searcher;
this.reRankQueryRescorer = reRankQueryRescorer;
@ -84,7 +89,7 @@ public class ReRankCollector extends TopDocsCollector {
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE; // since the scores will be needed by Rescorer as input regardless of mainCollector
return sort == null || sort.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
}
public TopDocs topDocs(int start, int howMany) {
@ -97,6 +102,10 @@ public class ReRankCollector extends TopDocsCollector {
return mainDocs;
}
if (sort != null) {
TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query);
}
ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)];
System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length);

View File

@ -1532,12 +1532,11 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
return TopScoreDocCollector.create(len);
} else {
// we have a sort
final boolean needScores = (cmd.getFlags() & GET_SCORES) != 0;
final Sort weightedSort = weightSort(cmd.getSort());
final CursorMark cursor = cmd.getCursorMark();
final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null);
return TopFieldCollector.create(weightedSort, len, searchAfter, needScores, true);
return TopFieldCollector.create(weightedSort, len, searchAfter, true);
}
}
@ -1624,6 +1623,9 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
totalHits = topCollector.getTotalHits();
TopDocs topDocs = topCollector.topDocs(0, len);
if (cmd.getSort() != null && query instanceof RankQuery == false && (cmd.getFlags() & GET_SCORES) != 0) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
populateNextCursorMarkFromTopDocs(qr, cmd, topDocs);
maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f;
@ -1732,6 +1734,9 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
assert (totalHits == set.size());
TopDocs topDocs = topCollector.topDocs(0, len);
if (cmd.getSort() != null && query instanceof RankQuery == false && (cmd.getFlags() & GET_SCORES) != 0) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
populateNextCursorMarkFromTopDocs(qr, cmd, topDocs);
maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f;
nDocsReturned = topDocs.scoreDocs.length;

View File

@ -17,6 +17,7 @@
package org.apache.solr.search.grouping;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort;
import java.io.IOException;
import java.util.List;
@ -39,6 +40,12 @@ public interface Command<T> {
*/
List<Collector> create() throws IOException;
/**
* Run post-collection steps.
* @throws IOException If I/O related errors occur
*/
default void postCollect(IndexSearcher searcher) throws IOException {}
/**
* Returns the results that the collectors created
* by {@link #create()} contain after a search has been executed.

View File

@ -163,6 +163,10 @@ public class CommandHandler {
} else {
searchWithTimeLimiter(query, filter, null);
}
for (Command command : commands) {
command.postCollect(searcher);
}
}
private DocSet computeGroupedDocSet(Query query, ProcessedFilter filter, List<Collector> collectors) throws IOException {

View File

@ -114,6 +114,7 @@ public class QueryCommand implements Command<QueryCommandResult> {
private TopDocsCollector topDocsCollector;
private FilterCollector filterCollector;
private MaxScoreCollector maxScoreCollector;
private TopDocs topDocs;
private QueryCommand(Sort sort, Query query, int docsToCollect, boolean needScores, DocSet docSet, String queryString) {
this.sort = sort;
@ -130,7 +131,7 @@ public class QueryCommand implements Command<QueryCommandResult> {
if (sort == null || sort.equals(Sort.RELEVANCE)) {
subCollector = topDocsCollector = TopScoreDocCollector.create(docsToCollect);
} else {
topDocsCollector = TopFieldCollector.create(sort, docsToCollect, needScores, true);
topDocsCollector = TopFieldCollector.create(sort, docsToCollect, true);
if (needScores) {
maxScoreCollector = new MaxScoreCollector();
subCollector = MultiCollector.wrap(topDocsCollector, maxScoreCollector);
@ -143,8 +144,15 @@ public class QueryCommand implements Command<QueryCommandResult> {
}
@Override
public QueryCommandResult result() {
TopDocs topDocs = topDocsCollector.topDocs();
public void postCollect(IndexSearcher searcher) throws IOException {
topDocs = topDocsCollector.topDocs();
if (needScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
}
@Override
public QueryCommandResult result() throws IOException {
float maxScore;
if (sort == null) {
maxScore = topDocs.scoreDocs.length == 0 ? Float.NaN : topDocs.scoreDocs[0].score;

View File

@ -25,7 +25,10 @@ import java.util.List;
import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.SearchGroup;
import org.apache.lucene.search.grouping.TermGroupSelector;
@ -45,6 +48,7 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
public static class Builder {
private Query query;
private SchemaField field;
private Sort groupSort;
private Sort withinGroupSort;
@ -53,6 +57,11 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
private boolean needScores = false;
private boolean needMaxScore = false;
public Builder setQuery(Query query) {
this.query = query;
return this;
}
public Builder setField(SchemaField field) {
this.field = field;
return this;
@ -89,16 +98,17 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
}
public TopGroupsFieldCommand build() {
if (field == null || groupSort == null || withinGroupSort == null || firstPhaseGroups == null ||
if (query == null || field == null || groupSort == null || withinGroupSort == null || firstPhaseGroups == null ||
maxDocPerGroup == null) {
throw new IllegalStateException("All required fields must be set");
}
return new TopGroupsFieldCommand(field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore);
return new TopGroupsFieldCommand(query, field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore);
}
}
private final Query query;
private final SchemaField field;
private final Sort groupSort;
private final Sort withinGroupSort;
@ -107,14 +117,17 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
private final boolean needScores;
private final boolean needMaxScore;
private TopGroupsCollector secondPassCollector;
private TopGroups<BytesRef> topGroups;
private TopGroupsFieldCommand(SchemaField field,
private TopGroupsFieldCommand(Query query,
SchemaField field,
Sort groupSort,
Sort withinGroupSort,
Collection<SearchGroup<BytesRef>> firstPhaseGroups,
int maxDocPerGroup,
boolean needScores,
boolean needMaxScore) {
this.query = query;
this.field = field;
this.groupSort = groupSort;
this.withinGroupSort = withinGroupSort;
@ -136,11 +149,11 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
ValueSource vs = fieldType.getValueSource(field, null);
Collection<SearchGroup<MutableValue>> v = GroupConverter.toMutable(field, firstPhaseGroups);
secondPassCollector = new TopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()),
v, groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore
v, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore
);
} else {
secondPassCollector = new TopGroupsCollector<>(new TermGroupSelector(field.getName()),
firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore
firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore
);
}
collectors.add(secondPassCollector);
@ -148,18 +161,27 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
}
@Override
@SuppressWarnings("unchecked")
public TopGroups<BytesRef> result() {
public void postCollect(IndexSearcher searcher) throws IOException {
if (firstPhaseGroups.isEmpty()) {
return new TopGroups<>(groupSort.getSort(), withinGroupSort.getSort(), 0, 0, new GroupDocs[0], Float.NaN);
topGroups = new TopGroups<>(groupSort.getSort(), withinGroupSort.getSort(), 0, 0, new GroupDocs[0], Float.NaN);
return;
}
FieldType fieldType = field.getType();
if (fieldType.getNumberType() != null) {
return GroupConverter.fromMutable(field, secondPassCollector.getTopGroups(0));
topGroups = GroupConverter.fromMutable(field, secondPassCollector.getTopGroups(0));
} else {
return secondPassCollector.getTopGroups(0);
topGroups = secondPassCollector.getTopGroups(0);
}
for (GroupDocs<?> group : topGroups.groups) {
TopFieldCollector.populateScores(group.scoreDocs, searcher, query);
}
}
@Override
@SuppressWarnings("unchecked")
public TopGroups<BytesRef> result() throws IOException {
return topGroups;
}
@Override

View File

@ -282,9 +282,8 @@ public class TestSort extends SolrTestCaseJ4 {
final String nullRep = luceneSort || sortMissingFirst && !reverse || sortMissingLast && reverse ? "" : "zzz";
final String nullRep2 = luceneSort2 || sortMissingFirst2 && !reverse2 || sortMissingLast2 && reverse2 ? "" : "zzz";
boolean trackScores = r.nextBoolean();
boolean scoreInOrder = r.nextBoolean();
final TopFieldCollector topCollector = TopFieldCollector.create(sort, top, trackScores, true);
final TopFieldCollector topCollector = TopFieldCollector.create(sort, top, true);
final List<MyDoc> collectedDocs = new ArrayList<>();
// delegate and collect docs ourselves

View File

@ -168,9 +168,9 @@ public class TestFieldCacheSortRandom extends LuceneTestCase {
int queryType = random.nextInt(2);
if (queryType == 0) {
hits = s.search(new ConstantScoreQuery(f),
hitCount, sort, random.nextBoolean());
hitCount, sort, false);
} else {
hits = s.search(f, hitCount, sort, random.nextBoolean());
hits = s.search(f, hitCount, sort, false);
}
if (VERBOSE) {