LUCENE-8412: TopFieldCollector no longer takes a trackDocScores option.

This commit is contained in:
Adrien Grand 2018-07-23 09:05:02 +02:00
parent 34686c00dd
commit 55bfadbce1
32 changed files with 346 additions and 264 deletions

View File

@ -60,9 +60,13 @@ API Changes
no longer have an option to compute the maximum score when sorting by field. no longer have an option to compute the maximum score when sorting by field.
(Adrien Grand) (Adrien Grand)
* LUCENE-8411: TopFieldCollector no longer takes a fillFields options, it now * LUCENE-8411: TopFieldCollector no longer takes a fillFields option, it now
always fills fields. (Adrien Grand) always fills fields. (Adrien Grand)
* LUCENE-8412: TopFieldCollector no longer takes a trackDocScores option. Scores
need to be set on top hits via TopFieldCollector#populateScores instead.
(Adrien Grand)
Changes in Runtime Behavior Changes in Runtime Behavior
* LUCENE-8333: Switch MoreLikeThis.setMaxDocFreqPct to use maxDoc instead of * LUCENE-8333: Switch MoreLikeThis.setMaxDocFreqPct to use maxDoc instead of

View File

@ -82,3 +82,10 @@ all matches.
Because filling sort values doesn't have a significant overhead, the fillFields Because filling sort values doesn't have a significant overhead, the fillFields
option has been removed from TopFieldCollector factory methods. Everything option has been removed from TopFieldCollector factory methods. Everything
behaves as if it was set to true. behaves as if it was set to true.
## TopFieldCollector no longer takes a trackDocScores option ##
Computing scores for every competitive hit at collection time is less efficient
than running a second pass that only computes scores for the documents that made
it to the top hits. As a consequence, the trackDocScores option has been removed;
callers that need scores should invoke the new TopFieldCollector#populateScores
helper method on the returned top hits instead.

View File

@ -112,7 +112,6 @@ public abstract class ReadTask extends PerfTask {
// Weight public again, we can go back to // Weight public again, we can go back to
// pulling the Weight ourselves: // pulling the Weight ourselves:
TopFieldCollector collector = TopFieldCollector.create(sort, numHits, TopFieldCollector collector = TopFieldCollector.create(sort, numHits,
withScore(),
withTotalHits()); withTotalHits());
searcher.search(q, collector); searcher.search(q, collector);
hits = collector.topDocs(); hits = collector.topDocs();
@ -208,12 +207,6 @@ public abstract class ReadTask extends PerfTask {
*/ */
public abstract boolean withTraverse(); public abstract boolean withTraverse();
/** Whether scores should be computed (only useful with
* field sort) */
public boolean withScore() {
return true;
}
/** Whether totalHits should be computed (only useful with /** Whether totalHits should be computed (only useful with
* field sort) */ * field sort) */
public boolean withTotalHits() { public boolean withTotalHits() {

View File

@ -29,7 +29,6 @@ import org.apache.lucene.search.SortField;
*/ */
public class SearchWithSortTask extends ReadTask { public class SearchWithSortTask extends ReadTask {
private boolean doScore = true;
private Sort sort; private Sort sort;
public SearchWithSortTask(PerfRunData runData) { public SearchWithSortTask(PerfRunData runData) {
@ -60,9 +59,6 @@ public class SearchWithSortTask extends ReadTask {
sortField0 = SortField.FIELD_DOC; sortField0 = SortField.FIELD_DOC;
} else if (field.equals("score")) { } else if (field.equals("score")) {
sortField0 = SortField.FIELD_SCORE; sortField0 = SortField.FIELD_SCORE;
} else if (field.equals("noscore")) {
doScore = false;
continue;
} else { } else {
int index = field.lastIndexOf(":"); int index = field.lastIndexOf(":");
String fieldName; String fieldName;
@ -116,11 +112,6 @@ public class SearchWithSortTask extends ReadTask {
return false; return false;
} }
@Override
public boolean withScore() {
return doScore;
}
@Override @Override
public Sort getSort() { public Sort getSort() {
if (sort == null) { if (sort == null) {

View File

@ -513,7 +513,7 @@ public class IndexSearcher {
@Override @Override
public TopFieldCollector newCollector() throws IOException { public TopFieldCollector newCollector() throws IOException {
// TODO: don't pay the price for accurate hit counts by default // TODO: don't pay the price for accurate hit counts by default
return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, doDocScores, true); return TopFieldCollector.create(rewrittenSort, cappedNumHits, after, true);
} }
@Override @Override
@ -528,7 +528,11 @@ public class IndexSearcher {
}; };
return search(query, manager); TopFieldDocs topDocs = search(query, manager);
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
return topDocs;
} }
/** /**

View File

@ -44,17 +44,12 @@ public class SortRescorer extends Rescorer {
// Copy ScoreDoc[] and sort by ascending docID: // Copy ScoreDoc[] and sort by ascending docID:
ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone(); ScoreDoc[] hits = firstPassTopDocs.scoreDocs.clone();
Arrays.sort(hits, Comparator<ScoreDoc> docIdComparator = Comparator.comparingInt(sd -> sd.doc);
new Comparator<ScoreDoc>() { Arrays.sort(hits, docIdComparator);
@Override
public int compare(ScoreDoc a, ScoreDoc b) {
return a.doc - b.doc;
}
});
List<LeafReaderContext> leaves = searcher.getIndexReader().leaves(); List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
TopFieldCollector collector = TopFieldCollector.create(sort, topN, true, true); TopFieldCollector collector = TopFieldCollector.create(sort, topN, true);
// Now merge sort docIDs from hits, with reader's leaves: // Now merge sort docIDs from hits, with reader's leaves:
int hitUpto = 0; int hitUpto = 0;
@ -90,7 +85,15 @@ public class SortRescorer extends Rescorer {
hitUpto++; hitUpto++;
} }
return collector.topDocs(); TopDocs rescoredDocs = collector.topDocs();
// set scores from the original score docs
assert hits.length == rescoredDocs.scoreDocs.length;
ScoreDoc[] rescoredDocsClone = rescoredDocs.scoreDocs.clone();
Arrays.sort(rescoredDocsClone, docIdComparator);
for (int i = 0; i < rescoredDocsClone.length; ++i) {
rescoredDocsClone[i].score = hits[i].score;
}
return rescoredDocs;
} }
@Override @Override

View File

@ -19,16 +19,20 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.ReaderUtil;
import org.apache.lucene.search.FieldValueHitQueue.Entry; import org.apache.lucene.search.FieldValueHitQueue.Entry;
import org.apache.lucene.util.FutureObjects;
import org.apache.lucene.util.PriorityQueue; import org.apache.lucene.util.PriorityQueue;
/** /**
* A {@link Collector} that sorts by {@link SortField} using * A {@link Collector} that sorts by {@link SortField} using
* {@link FieldComparator}s. * {@link FieldComparator}s.
* <p> * <p>
* See the {@link #create(org.apache.lucene.search.Sort, int, boolean, boolean)} method * See the {@link #create(org.apache.lucene.search.Sort, int, boolean)} method
* for instantiating a TopFieldCollector. * for instantiating a TopFieldCollector.
* *
* @lucene.experimental * @lucene.experimental
@ -44,10 +48,9 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final LeafFieldComparator comparator; final LeafFieldComparator comparator;
final int reverseMul; final int reverseMul;
final boolean mayNeedScoresTwice;
Scorer scorer; Scorer scorer;
MultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul, boolean mayNeedScoresTwice) { MultiComparatorLeafCollector(LeafFieldComparator[] comparators, int[] reverseMul) {
if (comparators.length == 1) { if (comparators.length == 1) {
this.reverseMul = reverseMul[0]; this.reverseMul = reverseMul[0];
this.comparator = comparators[0]; this.comparator = comparators[0];
@ -55,14 +58,10 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
this.reverseMul = 1; this.reverseMul = 1;
this.comparator = new MultiLeafFieldComparator(comparators, reverseMul); this.comparator = new MultiLeafFieldComparator(comparators, reverseMul);
} }
this.mayNeedScoresTwice = mayNeedScoresTwice;
} }
@Override @Override
public void setScorer(Scorer scorer) throws IOException { public void setScorer(Scorer scorer) throws IOException {
if (mayNeedScoresTwice && scorer instanceof ScoreCachingWrappingScorer == false) {
scorer = new ScoreCachingWrappingScorer(scorer);
}
comparator.setScorer(scorer); comparator.setScorer(scorer);
this.scorer = scorer; this.scorer = scorer;
} }
@ -93,20 +92,12 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final Sort sort; final Sort sort;
final FieldValueHitQueue<Entry> queue; final FieldValueHitQueue<Entry> queue;
final boolean trackDocScores;
final boolean mayNeedScoresTwice;
final boolean trackTotalHits; final boolean trackTotalHits;
public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits, public SimpleFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, int numHits, boolean trackTotalHits) {
boolean trackDocScores, boolean trackTotalHits) { super(queue, numHits, sort.needsScores());
super(queue, numHits, sort.needsScores() || trackDocScores);
this.sort = sort; this.sort = sort;
this.queue = queue; this.queue = queue;
this.trackDocScores = trackDocScores;
// If one of the sort fields needs scores, and if we also track scores, then
// we might call scorer.score() several times per doc so wrapping the scorer
// to cache scores would help
this.mayNeedScoresTwice = sort.needsScores() && trackDocScores;
this.trackTotalHits = trackTotalHits; this.trackTotalHits = trackTotalHits;
} }
@ -122,7 +113,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
canEarlyTerminate(sort, indexSort); canEarlyTerminate(sort, indexSort);
final int initialTotalHits = totalHits; final int initialTotalHits = totalHits;
return new MultiComparatorLeafCollector(comparators, reverseMul, mayNeedScoresTwice) { return new MultiComparatorLeafCollector(comparators, reverseMul) {
@Override @Override
public void collect(int doc) throws IOException { public void collect(int doc) throws IOException {
@ -146,10 +137,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
} }
} }
if (trackDocScores) {
score = scorer.score();
}
// This hit is competitive - replace bottom element in queue & adjustTop // This hit is competitive - replace bottom element in queue & adjustTop
comparator.copy(bottom.slot, doc); comparator.copy(bottom.slot, doc);
updateBottom(doc, score); updateBottom(doc, score);
@ -158,10 +145,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// Startup transient: queue hasn't gathered numHits yet // Startup transient: queue hasn't gathered numHits yet
final int slot = totalHits - 1; final int slot = totalHits - 1;
if (trackDocScores) {
score = scorer.score();
}
// Copy hit into queue // Copy hit into queue
comparator.copy(slot, doc); comparator.copy(slot, doc);
add(slot, doc, score); add(slot, doc, score);
@ -184,19 +167,15 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
final Sort sort; final Sort sort;
int collectedHits; int collectedHits;
final FieldValueHitQueue<Entry> queue; final FieldValueHitQueue<Entry> queue;
final boolean trackDocScores;
final FieldDoc after; final FieldDoc after;
final boolean mayNeedScoresTwice;
final boolean trackTotalHits; final boolean trackTotalHits;
public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits, public PagingFieldCollector(Sort sort, FieldValueHitQueue<Entry> queue, FieldDoc after, int numHits,
boolean trackDocScores, boolean trackTotalHits) { boolean trackTotalHits) {
super(queue, numHits, trackDocScores || sort.needsScores()); super(queue, numHits, sort.needsScores());
this.sort = sort; this.sort = sort;
this.queue = queue; this.queue = queue;
this.trackDocScores = trackDocScores;
this.after = after; this.after = after;
this.mayNeedScoresTwice = sort.needsScores() && trackDocScores;
this.trackTotalHits = trackTotalHits; this.trackTotalHits = trackTotalHits;
FieldComparator<?>[] comparators = queue.comparators; FieldComparator<?>[] comparators = queue.comparators;
@ -217,7 +196,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
indexSort != null && indexSort != null &&
canEarlyTerminate(sort, indexSort); canEarlyTerminate(sort, indexSort);
final int initialTotalHits = totalHits; final int initialTotalHits = totalHits;
return new MultiComparatorLeafCollector(queue.getComparators(context), queue.getReverseMul(), mayNeedScoresTwice) { return new MultiComparatorLeafCollector(queue.getComparators(context), queue.getReverseMul()) {
@Override @Override
public void collect(int doc) throws IOException { public void collect(int doc) throws IOException {
@ -256,10 +235,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// This hit is competitive - replace bottom element in queue & adjustTop // This hit is competitive - replace bottom element in queue & adjustTop
comparator.copy(bottom.slot, doc); comparator.copy(bottom.slot, doc);
// Compute score only if it is competitive.
if (trackDocScores) {
score = scorer.score();
}
updateBottom(doc, score); updateBottom(doc, score);
comparator.setBottom(bottom.slot); comparator.setBottom(bottom.slot);
@ -272,10 +247,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
// Copy hit into queue // Copy hit into queue
comparator.copy(slot, doc); comparator.copy(slot, doc);
// Compute score only if it is competitive.
if (trackDocScores) {
score = scorer.score();
}
bottom = pq.add(new Entry(slot, docBase + doc, score)); bottom = pq.add(new Entry(slot, docBase + doc, score));
queueFull = collectedHits == numHits; queueFull = collectedHits == numHits;
if (queueFull) { if (queueFull) {
@ -325,13 +296,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the sort criteria (SortFields). * the sort criteria (SortFields).
* @param numHits * @param numHits
* the number of results to collect. * the number of results to collect.
* @param trackDocScores
* specifies whether document scores should be tracked and set on the
* results. Note that if set to false, then the results' scores will
* be set to Float.NaN. Setting this to true affects performance, as
* it incurs the score computation on each competitive result.
* Therefore if document scores are not required by the application,
* it is recommended to set it to false.
* @param trackTotalHits * @param trackTotalHits
* specifies whether the total number of hits should be tracked. If * specifies whether the total number of hits should be tracked. If
* set to false, the value of {@link TopFieldDocs#totalHits} will be * set to false, the value of {@link TopFieldDocs#totalHits} will be
@ -339,9 +303,8 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* @return a {@link TopFieldCollector} instance which will sort the results by * @return a {@link TopFieldCollector} instance which will sort the results by
* the sort criteria. * the sort criteria.
*/ */
public static TopFieldCollector create(Sort sort, int numHits, public static TopFieldCollector create(Sort sort, int numHits, boolean trackTotalHits) {
boolean trackDocScores, boolean trackTotalHits) { return create(sort, numHits, null, trackTotalHits);
return create(sort, numHits, null, trackDocScores, trackTotalHits);
} }
/** /**
@ -358,14 +321,6 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the number of results to collect. * the number of results to collect.
* @param after * @param after
* only hits after this FieldDoc will be collected * only hits after this FieldDoc will be collected
* @param trackDocScores
* specifies whether document scores should be tracked and set on the
* results. Note that if set to false, then the results' scores will
* be set to Float.NaN. Setting this to true affects performance, as
* it incurs the score computation on each competitive result.
* Therefore if document scores are not required by the application,
* it is recommended to set it to false.
* <code>trackDocScores</code> to true as well.
* @param trackTotalHits * @param trackTotalHits
* specifies whether the total number of hits should be tracked. If * specifies whether the total number of hits should be tracked. If
* set to false, the value of {@link TopFieldDocs#totalHits} will be * set to false, the value of {@link TopFieldDocs#totalHits} will be
@ -374,7 +329,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
* the sort criteria. * the sort criteria.
*/ */
public static TopFieldCollector create(Sort sort, int numHits, FieldDoc after, public static TopFieldCollector create(Sort sort, int numHits, FieldDoc after,
boolean trackDocScores, boolean trackTotalHits) { boolean trackTotalHits) {
if (sort.fields.length == 0) { if (sort.fields.length == 0) {
throw new IllegalArgumentException("Sort must contain at least one field"); throw new IllegalArgumentException("Sort must contain at least one field");
@ -387,7 +342,7 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits); FieldValueHitQueue<Entry> queue = FieldValueHitQueue.create(sort.fields, numHits);
if (after == null) { if (after == null) {
return new SimpleFieldCollector(sort, queue, numHits, trackDocScores, trackTotalHits); return new SimpleFieldCollector(sort, queue, numHits, trackTotalHits);
} else { } else {
if (after.fields == null) { if (after.fields == null) {
throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search"); throw new IllegalArgumentException("after.fields wasn't set; you must pass fillFields=true for the previous search");
@ -397,7 +352,46 @@ public abstract class TopFieldCollector extends TopDocsCollector<Entry> {
throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length); throw new IllegalArgumentException("after.fields has " + after.fields.length + " values but sort has " + sort.getSort().length);
} }
return new PagingFieldCollector(sort, queue, after, numHits, trackDocScores, trackTotalHits); return new PagingFieldCollector(sort, queue, after, numHits, trackTotalHits);
}
}
/**
* Populate {@link ScoreDoc#score scores} of the given {@code topDocs}.
* @param topDocs the top docs to populate
* @param searcher the index searcher that has been used to compute {@code topDocs}
* @param query the query that has been used to compute {@code topDocs}
* @throws IllegalArgumentException if there is evidence that {@code topDocs}
* have been computed against a different searcher or a different query.
* @lucene.experimental
*/
public static void populateScores(ScoreDoc[] topDocs, IndexSearcher searcher, Query query) throws IOException {
// Get the score docs sorted in doc id order
topDocs = topDocs.clone();
Arrays.sort(topDocs, Comparator.comparingInt(scoreDoc -> scoreDoc.doc));
final Weight weight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE, 1);
List<LeafReaderContext> contexts = searcher.getIndexReader().leaves();
LeafReaderContext currentContext = null;
Scorer currentScorer = null;
for (ScoreDoc scoreDoc : topDocs) {
if (currentContext == null || scoreDoc.doc >= currentContext.docBase + currentContext.reader().maxDoc()) {
FutureObjects.checkIndex(scoreDoc.doc, searcher.getIndexReader().maxDoc());
int newContextIndex = ReaderUtil.subIndex(scoreDoc.doc, contexts);
currentContext = contexts.get(newContextIndex);
final ScorerSupplier scorerSupplier = weight.scorerSupplier(currentContext);
if (scorerSupplier == null) {
throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
}
currentScorer = scorerSupplier.get(1); // random-access
}
final int leafDoc = scoreDoc.doc - currentContext.docBase;
assert leafDoc >= 0;
final int advanced = currentScorer.iterator().advance(leafDoc);
if (leafDoc != advanced) {
throw new IllegalArgumentException("Doc id " + scoreDoc.doc + " doesn't match the query");
}
scoreDoc.score = currentScorer.score();
} }
} }

View File

@ -2324,11 +2324,11 @@ public class TestIndexSorting extends LuceneTestCase {
System.out.println("TEST: iter=" + iter + " numHits=" + numHits); System.out.println("TEST: iter=" + iter + " numHits=" + numHits);
} }
TopFieldCollector c1 = TopFieldCollector.create(sort, numHits, true, true); TopFieldCollector c1 = TopFieldCollector.create(sort, numHits, true);
s1.search(new MatchAllDocsQuery(), c1); s1.search(new MatchAllDocsQuery(), c1);
TopDocs hits1 = c1.topDocs(); TopDocs hits1 = c1.topDocs();
TopFieldCollector c2 = TopFieldCollector.create(sort, numHits, true, false); TopFieldCollector c2 = TopFieldCollector.create(sort, numHits, false);
s2.search(new MatchAllDocsQuery(), c2); s2.search(new MatchAllDocsQuery(), c2);
TopDocs hits2 = c2.topDocs(); TopDocs hits2 = c2.topDocs();

View File

@ -386,10 +386,10 @@ public class TestBoolean2 extends LuceneTestCase {
} }
// check diff (randomized) scorers (from AssertingSearcher) produce the same results // check diff (randomized) scorers (from AssertingSearcher) produce the same results
TopFieldCollector collector = TopFieldCollector.create(sort, 1000, true, false); TopFieldCollector collector = TopFieldCollector.create(sort, 1000, false);
searcher.search(q1, collector); searcher.search(q1, collector);
ScoreDoc[] hits1 = collector.topDocs().scoreDocs; ScoreDoc[] hits1 = collector.topDocs().scoreDocs;
collector = TopFieldCollector.create(sort, 1000, true, false); collector = TopFieldCollector.create(sort, 1000, false);
searcher.search(q1, collector); searcher.search(q1, collector);
ScoreDoc[] hits2 = collector.topDocs().scoreDocs; ScoreDoc[] hits2 = collector.topDocs().scoreDocs;
tot+=hits2.length; tot+=hits2.length;
@ -402,10 +402,10 @@ public class TestBoolean2 extends LuceneTestCase {
assertEquals(mulFactor*collector.totalHits + NUM_EXTRA_DOCS/2, hits4.totalHits); assertEquals(mulFactor*collector.totalHits + NUM_EXTRA_DOCS/2, hits4.totalHits);
// test diff (randomized) scorers produce the same results on bigSearcher as well // test diff (randomized) scorers produce the same results on bigSearcher as well
collector = TopFieldCollector.create(sort, 1000 * mulFactor, true, false); collector = TopFieldCollector.create(sort, 1000 * mulFactor, false);
bigSearcher.search(q1, collector); bigSearcher.search(q1, collector);
hits1 = collector.topDocs().scoreDocs; hits1 = collector.topDocs().scoreDocs;
collector = TopFieldCollector.create(sort, 1000 * mulFactor, true, false); collector = TopFieldCollector.create(sort, 1000 * mulFactor, false);
bigSearcher.search(q1, collector); bigSearcher.search(q1, collector);
hits2 = collector.topDocs().scoreDocs; hits2 = collector.topDocs().scoreDocs;
CheckHits.checkEqual(q1, hits1, hits2); CheckHits.checkEqual(q1, hits1, hits2);

View File

@ -86,7 +86,7 @@ public class TestElevationComparator extends LuceneTestCase {
new SortField(null, SortField.Type.SCORE, reversed) new SortField(null, SortField.Type.SCORE, reversed)
); );
TopDocsCollector<Entry> topCollector = TopFieldCollector.create(sort, 50, true, true); TopDocsCollector<Entry> topCollector = TopFieldCollector.create(sort, 50, true);
searcher.search(newq.build(), topCollector); searcher.search(newq.build(), topCollector);
TopDocs topDocs = topCollector.topDocs(0, 10); TopDocs topDocs = topCollector.topDocs(0, 10);

View File

@ -146,7 +146,7 @@ public class TestSortRandom extends LuceneTestCase {
} }
final int hitCount = TestUtil.nextInt(random, 1, r.maxDoc() + 20); final int hitCount = TestUtil.nextInt(random, 1, r.maxDoc() + 20);
final RandomQuery f = new RandomQuery(random.nextLong(), random.nextFloat(), docValues); final RandomQuery f = new RandomQuery(random.nextLong(), random.nextFloat(), docValues);
hits = s.search(f, hitCount, sort, random.nextBoolean()); hits = s.search(f, hitCount, sort, false);
if (VERBOSE) { if (VERBOSE) {
System.out.println("\nTEST: iter=" + iter + " " + hits.totalHits + " hits; topN=" + hitCount + "; reverse=" + reverse + "; sortMissingLast=" + sortMissingLast + " sort=" + sort); System.out.println("\nTEST: iter=" + iter + " " + hits.totalHits + " hits; topN=" + hitCount + "; reverse=" + reverse + "; sortMissingLast=" + sortMissingLast + " sort=" + sort);

View File

@ -281,7 +281,7 @@ public class TestTopDocsMerge extends LuceneTestCase {
topHits = searcher.search(query, numHits); topHits = searcher.search(query, numHits);
} }
} else { } else {
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true, true); final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true);
searcher.search(query, c); searcher.search(query, c);
if (useFrom) { if (useFrom) {
from = TestUtil.nextInt(random(), 0, numHits - 1); from = TestUtil.nextInt(random(), 0, numHits - 1);
@ -330,7 +330,7 @@ public class TestTopDocsMerge extends LuceneTestCase {
if (sort == null) { if (sort == null) {
subHits = subSearcher.search(w, numHits); subHits = subSearcher.search(w, numHits);
} else { } else {
final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true, true); final TopFieldCollector c = TopFieldCollector.create(sort, numHits, true);
subSearcher.search(w, c); subSearcher.search(w, c);
subHits = c.topDocs(0, numHits); subHits = c.topDocs(0, numHits);
} }

View File

@ -18,11 +18,14 @@ package org.apache.lucene.search;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays;
import java.util.Comparator;
import org.apache.lucene.document.Document; import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.NumericDocValuesField; import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.StringField; import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.RandomIndexWriter;
@ -70,8 +73,7 @@ public class TestTopFieldCollector extends LuceneTestCase {
Sort[] sort = new Sort[] { new Sort(SortField.FIELD_DOC), new Sort() }; Sort[] sort = new Sort[] { new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) { for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery(); Query q = new MatchAllDocsQuery();
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
false, true);
is.search(q, tdc); is.search(q, tdc);
@ -83,14 +85,13 @@ public class TestTopFieldCollector extends LuceneTestCase {
} }
} }
public void testSortWithoutScoreTracking() throws Exception { public void testSort() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators. // Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() }; Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) { for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery(); Query q = new MatchAllDocsQuery();
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, false, TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
true);
is.search(q, tdc); is.search(q, tdc);
@ -110,10 +111,10 @@ public class TestTopFieldCollector extends LuceneTestCase {
// the index is not sorted // the index is not sorted
TopDocsCollector<Entry> tdc; TopDocsCollector<Entry> tdc;
if (i % 2 == 0) { if (i % 2 == 0) {
tdc = TopFieldCollector.create(sort, 10, false, false); tdc = TopFieldCollector.create(sort, 10, false);
} else { } else {
FieldDoc fieldDoc = new FieldDoc(1, Float.NaN, new Object[] { 1 }); FieldDoc fieldDoc = new FieldDoc(1, Float.NaN, new Object[] { 1 });
tdc = TopFieldCollector.create(sort, 10, fieldDoc, false, false); tdc = TopFieldCollector.create(sort, 10, fieldDoc, false);
} }
is.search(q, tdc); is.search(q, tdc);
@ -126,31 +127,12 @@ public class TestTopFieldCollector extends LuceneTestCase {
} }
} }
public void testSortWithScoreTracking() throws Exception { public void testSortNoResults() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators. // Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() }; Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) { for(int i = 0; i < sort.length; i++) {
Query q = new MatchAllDocsQuery(); TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true);
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true,
true);
is.search(q, tdc);
TopDocs td = tdc.topDocs();
ScoreDoc[] sd = td.scoreDocs;
for(int j = 0; j < sd.length; j++) {
assertTrue(!Float.isNaN(sd[j].score));
}
}
}
public void testSortWithScoreTrackingNoResults() throws Exception {
// Two Sort criteria to instantiate the multi/single comparators.
Sort[] sort = new Sort[] {new Sort(SortField.FIELD_DOC), new Sort() };
for(int i = 0; i < sort.length; i++) {
TopDocsCollector<Entry> tdc = TopFieldCollector.create(sort[i], 10, true, true);
TopDocs td = tdc.topDocs(); TopDocs td = tdc.topDocs();
assertEquals(0, td.totalHits); assertEquals(0, td.totalHits);
} }
@ -182,59 +164,112 @@ public class TestTopFieldCollector extends LuceneTestCase {
.build(); .build();
final IndexSearcher searcher = new IndexSearcher(reader); final IndexSearcher searcher = new IndexSearcher(reader);
for (Sort sort : new Sort[] {new Sort(SortField.FIELD_SCORE), new Sort(new SortField("f", SortField.Type.SCORE))}) { for (Sort sort : new Sort[] {new Sort(SortField.FIELD_SCORE), new Sort(new SortField("f", SortField.Type.SCORE))}) {
for (boolean doDocScores : new boolean[] {false, true}) { final TopFieldCollector topCollector = TopFieldCollector.create(sort, TestUtil.nextInt(random(), 1, 2), true);
final TopFieldCollector topCollector = TopFieldCollector.create(sort, TestUtil.nextInt(random(), 1, 2), doDocScores, true); final Collector assertingCollector = new Collector() {
final Collector assertingCollector = new Collector() { @Override
@Override public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException { final LeafCollector in = topCollector.getLeafCollector(context);
final LeafCollector in = topCollector.getLeafCollector(context); return new FilterLeafCollector(in) {
return new FilterLeafCollector(in) { @Override
@Override public void setScorer(final Scorer scorer) throws IOException {
public void setScorer(final Scorer scorer) throws IOException { Scorer s = new Scorer(null) {
Scorer s = new Scorer(null) {
int lastComputedDoc = -1; int lastComputedDoc = -1;
@Override @Override
public float score() throws IOException { public float score() throws IOException {
if (lastComputedDoc == docID()) { if (lastComputedDoc == docID()) {
throw new AssertionError("Score computed twice on " + docID()); throw new AssertionError("Score computed twice on " + docID());
}
lastComputedDoc = docID();
return scorer.score();
} }
lastComputedDoc = docID();
return scorer.score();
}
@Override @Override
public float getMaxScore(int upTo) throws IOException { public float getMaxScore(int upTo) throws IOException {
return scorer.getMaxScore(upTo); return scorer.getMaxScore(upTo);
} }
@Override @Override
public int docID() { public int docID() {
return scorer.docID(); return scorer.docID();
} }
@Override @Override
public DocIdSetIterator iterator() { public DocIdSetIterator iterator() {
return scorer.iterator(); return scorer.iterator();
} }
}; };
super.setScorer(s); super.setScorer(s);
} }
}; };
} }
@Override @Override
public ScoreMode scoreMode() { public ScoreMode scoreMode() {
return topCollector.scoreMode(); return topCollector.scoreMode();
} }
}; };
searcher.search(query, assertingCollector); searcher.search(query, assertingCollector);
}
} }
reader.close(); reader.close();
w.close(); w.close();
dir.close(); dir.close();
} }
/**
 * Checks that {@link TopFieldCollector#populateScores} gives every field-sorted hit the
 * same score that a plain relevance search reports for that document, and that the hits
 * stay at the same positions with the same sort values.
 */
public void testPopulateScores() throws IOException {
  Directory dir = newDirectory();
  RandomIndexWriter w = new RandomIndexWriter(random(), dir);

  // One Document instance is re-indexed several times with different field values.
  Document doc = new Document();
  TextField field = new TextField("f", "foo bar", Store.NO);
  NumericDocValuesField sortField = new NumericDocValuesField("sort", 0);
  doc.add(field);
  doc.add(sortField);
  w.addDocument(doc);

  field.setStringValue("");
  sortField.setLongValue(3);
  w.addDocument(doc);

  field.setStringValue("foo foo bar");
  sortField.setLongValue(2);
  w.addDocument(doc);

  // NOTE(review): presumably flushed here so that hits span more than one segment - confirm.
  w.flush();

  field.setStringValue("foo");
  sortField.setLongValue(2);
  w.addDocument(doc);

  field.setStringValue("bar bar bar");
  sortField.setLongValue(0);
  w.addDocument(doc);

  IndexReader reader = w.getReader();
  w.close();
  IndexSearcher searcher = newSearcher(reader);

  // Orders hits by doc id; used both to sort the reference hits and to look them up.
  final Comparator<ScoreDoc> byDocId = Comparator.comparingInt(sd -> sd.doc);

  for (String queryText : new String[] { "foo", "bar" }) {
    Query query = new TermQuery(new Term("f", queryText));
    for (boolean reverse : new boolean[] {false, true}) {
      // Reference scores: relevance hits re-ordered by doc id so they can be binary-searched.
      ScoreDoc[] relevanceHits = searcher.search(query, 10).scoreDocs;
      Arrays.sort(relevanceHits, byDocId);

      Sort sort = new Sort(new SortField("sort", SortField.Type.LONG, reverse));
      ScoreDoc[] fieldHits = searcher.search(query, 10, sort).scoreDocs;
      // Shallow copy: a distinct array over the same ScoreDoc instances, so comparing the
      // two arrays position by position detects any reordering of the array passed in.
      ScoreDoc[] populated = fieldHits.clone();
      TopFieldCollector.populateScores(populated, searcher, query);

      for (int i = 0; i < populated.length; ++i) {
        ScoreDoc hit = populated[i];
        assertEquals(hit.doc, fieldHits[i].doc);
        assertSame(((FieldDoc) hit).fields, ((FieldDoc) fieldHits[i]).fields);
        int refIndex = Arrays.binarySearch(relevanceHits, hit, byDocId);
        assertEquals(hit.score, relevanceHits[refIndex].score, 0f);
      }
    }
  }

  reader.close();
  dir.close();
}
} }

View File

@ -136,9 +136,8 @@ public class TestTopFieldCollectorEarlyTermination extends LuceneTestCase {
} else { } else {
after = null; after = null;
} }
final boolean trackDocScores = random().nextBoolean(); final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, after, true);
final TopFieldCollector collector1 = TopFieldCollector.create(sort, numHits, after, trackDocScores, true); final TopFieldCollector collector2 = TopFieldCollector.create(sort, numHits, after, false);
final TopFieldCollector collector2 = TopFieldCollector.create(sort, numHits, after, trackDocScores, false);
final Query query; final Query query;
if (random().nextBoolean()) { if (random().nextBoolean()) {

View File

@ -241,7 +241,7 @@ public class DrillSideways {
@Override @Override
public TopFieldCollector newCollector() throws IOException { public TopFieldCollector newCollector() throws IOException {
return TopFieldCollector.create(sort, fTopN, after, doDocScores, true); return TopFieldCollector.create(sort, fTopN, after, true);
} }
@Override @Override
@ -255,14 +255,22 @@ public class DrillSideways {
}; };
ConcurrentDrillSidewaysResult<TopFieldDocs> r = search(query, collectorManager); ConcurrentDrillSidewaysResult<TopFieldDocs> r = search(query, collectorManager);
return new DrillSidewaysResult(r.facets, r.collectorResult); TopFieldDocs topDocs = r.collectorResult;
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
return new DrillSidewaysResult(r.facets, topDocs);
} else { } else {
final TopFieldCollector hitCollector = final TopFieldCollector hitCollector =
TopFieldCollector.create(sort, fTopN, after, doDocScores, true); TopFieldCollector.create(sort, fTopN, after, true);
DrillSidewaysResult r = search(query, hitCollector); DrillSidewaysResult r = search(query, hitCollector);
return new DrillSidewaysResult(r.facets, hitCollector.topDocs()); TopFieldDocs topDocs = hitCollector.topDocs();
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
return new DrillSidewaysResult(r.facets, topDocs);
} }
} else { } else {
return search(after, query, topN); return search(after, query, topN);

View File

@ -230,7 +230,6 @@ public class FacetsCollector extends SimpleCollector implements Collector {
} }
hitsCollector = TopFieldCollector.create(sort, n, hitsCollector = TopFieldCollector.create(sort, n,
(FieldDoc) after, (FieldDoc) after,
doDocScores,
true); // TODO: can we disable exact hit counts true); // TODO: can we disable exact hit counts
} else { } else {
hitsCollector = TopScoreDocCollector.create(n, after, true); hitsCollector = TopScoreDocCollector.create(n, after, true);
@ -238,6 +237,9 @@ public class FacetsCollector extends SimpleCollector implements Collector {
searcher.search(q, MultiCollector.wrap(hitsCollector, fc)); searcher.search(q, MultiCollector.wrap(hitsCollector, fc));
topDocs = hitsCollector.topDocs(); topDocs = hitsCollector.topDocs();
if (doDocScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, q);
}
} }
return topDocs; return topDocs;
} }

View File

@ -304,7 +304,7 @@ public class BlockGroupingCollector extends SimpleCollector {
collector = TopScoreDocCollector.create(maxDocsPerGroup); collector = TopScoreDocCollector.create(maxDocsPerGroup);
} else { } else {
// Sort by fields // Sort by fields
collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, needsScores, true); // TODO: disable exact counts? collector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, true); // TODO: disable exact counts?
} }
float groupMaxScore = needsScores ? Float.NEGATIVE_INFINITY : Float.NaN; float groupMaxScore = needsScores ? Float.NEGATIVE_INFINITY : Float.NaN;

View File

@ -50,7 +50,6 @@ public class GroupingSearch {
private int groupDocsOffset; private int groupDocsOffset;
private int groupDocsLimit = 1; private int groupDocsLimit = 1;
private boolean includeScores = true;
private boolean includeMaxScore = true; private boolean includeMaxScore = true;
private Double maxCacheRAMMB; private Double maxCacheRAMMB;
@ -153,8 +152,7 @@ public class GroupingSearch {
int topNInsideGroup = groupDocsOffset + groupDocsLimit; int topNInsideGroup = groupDocsOffset + groupDocsLimit;
TopGroupsCollector secondPassCollector TopGroupsCollector secondPassCollector
= new TopGroupsCollector(grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup, = new TopGroupsCollector(grouper, topSearchGroups, groupSort, sortWithinGroup, topNInsideGroup, includeMaxScore);
includeScores, includeMaxScore);
if (cachedCollector != null && cachedCollector.isCached()) { if (cachedCollector != null && cachedCollector.isCached()) {
cachedCollector.replay(secondPassCollector); cachedCollector.replay(secondPassCollector);
@ -173,7 +171,7 @@ public class GroupingSearch {
int topN = groupOffset + groupLimit; int topN = groupOffset + groupLimit;
final Query endDocsQuery = searcher.rewrite(this.groupEndDocs); final Query endDocsQuery = searcher.rewrite(this.groupEndDocs);
final Weight groupEndDocs = searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1); final Weight groupEndDocs = searcher.createWeight(endDocsQuery, ScoreMode.COMPLETE_NO_SCORES, 1);
BlockGroupingCollector c = new BlockGroupingCollector(groupSort, topN, includeScores, groupEndDocs); BlockGroupingCollector c = new BlockGroupingCollector(groupSort, topN, groupSort.needsScores() || sortWithinGroup.needsScores(), groupEndDocs);
searcher.search(query, c); searcher.search(query, c);
int topNInsideGroup = groupDocsOffset + groupDocsLimit; int topNInsideGroup = groupDocsOffset + groupDocsLimit;
return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup); return c.getTopGroups(sortWithinGroup, groupOffset, groupDocsOffset, topNInsideGroup);
@ -268,17 +266,6 @@ public class GroupingSearch {
return this; return this;
} }
/**
* Whether to include the scores per doc inside a group.
*
* @param includeScores Whether to include the scores per doc inside a group
* @return <code>this</code>
*/
public GroupingSearch setIncludeScores(boolean includeScores) {
this.includeScores = includeScores;
return this;
}
/** /**
* Whether to include the score of the most relevant document per group. * Whether to include the score of the most relevant document per group.
* *

View File

@ -54,13 +54,12 @@ public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> {
* @param groupSort the order in which groups are returned * @param groupSort the order in which groups are returned
* @param withinGroupSort the order in which documents are sorted in each group * @param withinGroupSort the order in which documents are sorted in each group
* @param maxDocsPerGroup the maximum number of docs to collect for each group * @param maxDocsPerGroup the maximum number of docs to collect for each group
* @param getScores if true, record the scores of all docs in each group
* @param getMaxScores if true, record the maximum score for each group * @param getMaxScores if true, record the maximum score for each group
*/ */
public TopGroupsCollector(GroupSelector<T> groupSelector, Collection<SearchGroup<T>> groups, Sort groupSort, Sort withinGroupSort, public TopGroupsCollector(GroupSelector<T> groupSelector, Collection<SearchGroup<T>> groups, Sort groupSort, Sort withinGroupSort,
int maxDocsPerGroup, boolean getScores, boolean getMaxScores) { int maxDocsPerGroup, boolean getMaxScores) {
super(groupSelector, groups, super(groupSelector, groups,
new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getScores, getMaxScores)); new TopDocsReducer<>(withinGroupSort, maxDocsPerGroup, getMaxScores));
this.groupSort = Objects.requireNonNull(groupSort); this.groupSort = Objects.requireNonNull(groupSort);
this.withinGroupSort = Objects.requireNonNull(withinGroupSort); this.withinGroupSort = Objects.requireNonNull(withinGroupSort);
this.maxDocsPerGroup = maxDocsPerGroup; this.maxDocsPerGroup = maxDocsPerGroup;
@ -114,13 +113,13 @@ public class TopGroupsCollector<T> extends SecondPassGroupingCollector<T> {
private final boolean needsScores; private final boolean needsScores;
TopDocsReducer(Sort withinGroupSort, TopDocsReducer(Sort withinGroupSort,
int maxDocsPerGroup, boolean getScores, boolean getMaxScores) { int maxDocsPerGroup, boolean getMaxScores) {
this.needsScores = getScores || getMaxScores || withinGroupSort.needsScores(); this.needsScores = getMaxScores || withinGroupSort.needsScores();
if (withinGroupSort == Sort.RELEVANCE) { if (withinGroupSort == Sort.RELEVANCE) {
supplier = () -> new TopDocsAndMaxScoreCollector(true, TopScoreDocCollector.create(maxDocsPerGroup), null); supplier = () -> new TopDocsAndMaxScoreCollector(true, TopScoreDocCollector.create(maxDocsPerGroup), null);
} else { } else {
supplier = () -> { supplier = () -> {
TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, getScores, true); // TODO: disable exact counts? TopFieldCollector topDocsCollector = TopFieldCollector.create(withinGroupSort, maxDocsPerGroup, true); // TODO: disable exact counts?
MaxScoreCollector maxScoreCollector = getMaxScores ? new MaxScoreCollector() : null; MaxScoreCollector maxScoreCollector = getMaxScores ? new MaxScoreCollector() : null;
return new TopDocsAndMaxScoreCollector(false, topDocsCollector, maxScoreCollector); return new TopDocsAndMaxScoreCollector(false, topDocsCollector, maxScoreCollector);
}; };

View File

@ -145,7 +145,7 @@ public class TestGrouping extends LuceneTestCase {
final FirstPassGroupingCollector<?> c1 = createRandomFirstPassCollector(groupField, groupSort, 10); final FirstPassGroupingCollector<?> c1 = createRandomFirstPassCollector(groupField, groupSort, 10);
indexSearcher.search(new TermQuery(new Term("content", "random")), c1); indexSearcher.search(new TermQuery(new Term("content", "random")), c1);
final TopGroupsCollector<?> c2 = createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true, true); final TopGroupsCollector<?> c2 = createSecondPassCollector(c1, groupSort, Sort.RELEVANCE, 0, 5, true);
indexSearcher.search(new TermQuery(new Term("content", "random")), c2); indexSearcher.search(new TermQuery(new Term("content", "random")), c2);
final TopGroups<?> groups = c2.getTopGroups(0); final TopGroups<?> groups = c2.getTopGroups(0);
@ -218,11 +218,10 @@ public class TestGrouping extends LuceneTestCase {
Sort sortWithinGroup, Sort sortWithinGroup,
int groupOffset, int groupOffset,
int maxDocsPerGroup, int maxDocsPerGroup,
boolean getScores,
boolean getMaxScores) throws IOException { boolean getMaxScores) throws IOException {
Collection<SearchGroup<T>> searchGroups = firstPassGroupingCollector.getTopGroups(groupOffset); Collection<SearchGroup<T>> searchGroups = firstPassGroupingCollector.getTopGroups(groupOffset);
return new TopGroupsCollector<>(firstPassGroupingCollector.getGroupSelector(), searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getScores, getMaxScores); return new TopGroupsCollector<>(firstPassGroupingCollector.getGroupSelector(), searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
} }
// Basically converts searchGroups from MutableValue to BytesRef if grouping by ValueSource // Basically converts searchGroups from MutableValue to BytesRef if grouping by ValueSource
@ -233,11 +232,10 @@ public class TestGrouping extends LuceneTestCase {
Sort groupSort, Sort groupSort,
Sort sortWithinGroup, Sort sortWithinGroup,
int maxDocsPerGroup, int maxDocsPerGroup,
boolean getScores,
boolean getMaxScores) throws IOException { boolean getMaxScores) throws IOException {
if (firstPassGroupingCollector.getGroupSelector().getClass().isAssignableFrom(TermGroupSelector.class)) { if (firstPassGroupingCollector.getGroupSelector().getClass().isAssignableFrom(TermGroupSelector.class)) {
GroupSelector<BytesRef> selector = (GroupSelector<BytesRef>) firstPassGroupingCollector.getGroupSelector(); GroupSelector<BytesRef> selector = (GroupSelector<BytesRef>) firstPassGroupingCollector.getGroupSelector();
return new TopGroupsCollector<>(selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup , getScores, getMaxScores); return new TopGroupsCollector<>(selector, searchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
} else { } else {
ValueSource vs = new BytesRefFieldSource(groupField); ValueSource vs = new BytesRefFieldSource(groupField);
List<SearchGroup<MutableValue>> mvalSearchGroups = new ArrayList<>(searchGroups.size()); List<SearchGroup<MutableValue>> mvalSearchGroups = new ArrayList<>(searchGroups.size());
@ -254,7 +252,7 @@ public class TestGrouping extends LuceneTestCase {
mvalSearchGroups.add(sg); mvalSearchGroups.add(sg);
} }
ValueSourceGroupSelector selector = new ValueSourceGroupSelector(vs, new HashMap<>()); ValueSourceGroupSelector selector = new ValueSourceGroupSelector(vs, new HashMap<>());
return new TopGroupsCollector<>(selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getScores, getMaxScores); return new TopGroupsCollector<>(selector, mvalSearchGroups, groupSort, sortWithinGroup, maxDocsPerGroup, getMaxScores);
} }
} }
@ -437,7 +435,6 @@ public class TestGrouping extends LuceneTestCase {
private TopGroups<BytesRef> slowGrouping(GroupDoc[] groupDocs, private TopGroups<BytesRef> slowGrouping(GroupDoc[] groupDocs,
String searchTerm, String searchTerm,
boolean getScores,
boolean getMaxScores, boolean getMaxScores,
boolean doAllGroups, boolean doAllGroups,
Sort groupSort, Sort groupSort,
@ -507,7 +504,7 @@ public class TestGrouping extends LuceneTestCase {
for(int docIDX=docOffset; docIDX < docIDXLimit; docIDX++) { for(int docIDX=docOffset; docIDX < docIDXLimit; docIDX++) {
final GroupDoc d = docs.get(docIDX); final GroupDoc d = docs.get(docIDX);
final FieldDoc fd; final FieldDoc fd;
fd = new FieldDoc(d.id, getScores ? d.score : Float.NaN, fillFields(d, docSort)); fd = new FieldDoc(d.id, Float.NaN, fillFields(d, docSort));
hits[docIDX-docOffset] = fd; hits[docIDX-docOffset] = fd;
} }
} else { } else {
@ -829,14 +826,11 @@ public class TestGrouping extends LuceneTestCase {
} }
final String searchTerm = "real" + random().nextInt(3); final String searchTerm = "real" + random().nextInt(3);
boolean getScores = random().nextBoolean();
final boolean getMaxScores = random().nextBoolean(); final boolean getMaxScores = random().nextBoolean();
final Sort groupSort = getRandomSort(); final Sort groupSort = getRandomSort();
//final Sort groupSort = new Sort(new SortField[] {new SortField("sort1", SortField.STRING), new SortField("id", SortField.INT)}); //final Sort groupSort = new Sort(new SortField[] {new SortField("sort1", SortField.STRING), new SortField("id", SortField.INT)});
final Sort docSort = getRandomSort(); final Sort docSort = getRandomSort();
getScores |= (groupSort.needsScores() || docSort.needsScores());
final int topNGroups = TestUtil.nextInt(random(), 1, 30); final int topNGroups = TestUtil.nextInt(random(), 1, 30);
//final int topNGroups = 10; //final int topNGroups = 10;
final int docsPerGroup = TestUtil.nextInt(random(), 1, 50); final int docsPerGroup = TestUtil.nextInt(random(), 1, 50);
@ -850,7 +844,7 @@ public class TestGrouping extends LuceneTestCase {
final boolean doCache = random().nextBoolean(); final boolean doCache = random().nextBoolean();
final boolean doAllGroups = random().nextBoolean(); final boolean doAllGroups = random().nextBoolean();
if (VERBOSE) { if (VERBOSE) {
System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " dF=" + r.docFreq(new Term("content", searchTerm)) +" dFBlock=" + rBlocks.docFreq(new Term("content", searchTerm)) + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getScores=" + getScores + " getMaxScores=" + getMaxScores); System.out.println("TEST: groupSort=" + groupSort + " docSort=" + docSort + " searchTerm=" + searchTerm + " dF=" + r.docFreq(new Term("content", searchTerm)) +" dFBlock=" + rBlocks.docFreq(new Term("content", searchTerm)) + " topNGroups=" + topNGroups + " groupOffset=" + groupOffset + " docOffset=" + docOffset + " doCache=" + doCache + " docsPerGroup=" + docsPerGroup + " doAllGroups=" + doAllGroups + " getMaxScores=" + getMaxScores);
} }
String groupField = "group"; String groupField = "group";
@ -935,7 +929,7 @@ public class TestGrouping extends LuceneTestCase {
// Get 1st pass top groups using shards // Get 1st pass top groups using shards
final TopGroups<BytesRef> topGroupsShards = searchShards(s, shards.subSearchers, query, groupSort, docSort, final TopGroups<BytesRef> topGroupsShards = searchShards(s, shards.subSearchers, query, groupSort, docSort,
groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores, true, true); groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, true, true);
final TopGroupsCollector<?> c2; final TopGroupsCollector<?> c2;
if (topGroups != null) { if (topGroups != null) {
@ -946,7 +940,7 @@ public class TestGrouping extends LuceneTestCase {
} }
} }
c2 = createSecondPassCollector(c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getScores, getMaxScores); c2 = createSecondPassCollector(c1, groupSort, docSort, groupOffset, docOffset + docsPerGroup, getMaxScores);
if (doCache) { if (doCache) {
if (cCache.isCached()) { if (cCache.isCached()) {
if (VERBOSE) { if (VERBOSE) {
@ -977,7 +971,7 @@ public class TestGrouping extends LuceneTestCase {
} }
} }
final TopGroups<BytesRef> expectedGroups = slowGrouping(groupDocs, searchTerm, getScores, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset); final TopGroups<BytesRef> expectedGroups = slowGrouping(groupDocs, searchTerm, getMaxScores, doAllGroups, groupSort, docSort, topNGroups, docsPerGroup, groupOffset, docOffset);
if (VERBOSE) { if (VERBOSE) {
if (expectedGroups == null) { if (expectedGroups == null) {
@ -1023,17 +1017,16 @@ public class TestGrouping extends LuceneTestCase {
} }
} }
assertEquals(docIDToID, expectedGroups, groupsResult, true, true, getScores, true); assertEquals(docIDToID, expectedGroups, groupsResult, true, true, true);
// Confirm merged shards match: // Confirm merged shards match:
assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, getScores, true); assertEquals(docIDToID, expectedGroups, topGroupsShards, true, false, true);
if (topGroupsShards != null) { if (topGroupsShards != null) {
verifyShards(shards.docStarts, topGroupsShards); verifyShards(shards.docStarts, topGroupsShards);
} }
final boolean needsScores = getScores || getMaxScores || docSort == null; final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups,
final BlockGroupingCollector c3 = new BlockGroupingCollector(groupSort, groupOffset+topNGroups, needsScores, groupSort.needsScores() || docSort.needsScores(), sBlocks.createWeight(sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
sBlocks.createWeight(sBlocks.rewrite(lastDocInBlock), ScoreMode.COMPLETE_NO_SCORES, 1));
final AllGroupsCollector<BytesRef> allGroupsCollector2; final AllGroupsCollector<BytesRef> allGroupsCollector2;
final Collector c4; final Collector c4;
if (doAllGroups) { if (doAllGroups) {
@ -1079,7 +1072,7 @@ public class TestGrouping extends LuceneTestCase {
// Get shard'd block grouping result: // Get shard'd block grouping result:
final TopGroups<BytesRef> topGroupsBlockShards = searchShards(sBlocks, shardsBlocks.subSearchers, query, final TopGroups<BytesRef> topGroupsBlockShards = searchShards(sBlocks, shardsBlocks.subSearchers, query,
groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getScores, getMaxScores, false, false); groupSort, docSort, groupOffset, topNGroups, docOffset, docsPerGroup, getMaxScores, false, false);
if (expectedGroups != null) { if (expectedGroups != null) {
// Fixup scores for reader2 // Fixup scores for reader2
@ -1122,8 +1115,8 @@ public class TestGrouping extends LuceneTestCase {
} }
} }
assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, getScores, false); assertEquals(docIDToIDBlocks, expectedGroups, groupsResultBlocks, false, true, false);
assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, getScores, false); assertEquals(docIDToIDBlocks, expectedGroups, topGroupsBlockShards, false, false, false);
} }
r.close(); r.close();
@ -1146,7 +1139,7 @@ public class TestGrouping extends LuceneTestCase {
} }
private TopGroups<BytesRef> searchShards(IndexSearcher topSearcher, ShardSearcher[] subSearchers, Query query, Sort groupSort, Sort docSort, int groupOffset, int topNGroups, int docOffset, private TopGroups<BytesRef> searchShards(IndexSearcher topSearcher, ShardSearcher[] subSearchers, Query query, Sort groupSort, Sort docSort, int groupOffset, int topNGroups, int docOffset,
int topNDocs, boolean getScores, boolean getMaxScores, boolean canUseIDV, boolean preFlex) throws Exception { int topNDocs, boolean getMaxScores, boolean canUseIDV, boolean preFlex) throws Exception {
// TODO: swap in caching, all groups collector hereassertEquals(expected.totalHitCount, actual.totalHitCount); // TODO: swap in caching, all groups collector hereassertEquals(expected.totalHitCount, actual.totalHitCount);
// too... // too...
@ -1154,7 +1147,7 @@ public class TestGrouping extends LuceneTestCase {
System.out.println("TEST: " + subSearchers.length + " shards: " + Arrays.toString(subSearchers) + " canUseIDV=" + canUseIDV); System.out.println("TEST: " + subSearchers.length + " shards: " + Arrays.toString(subSearchers) + " canUseIDV=" + canUseIDV);
} }
// Run 1st pass collector to get top groups per shard // Run 1st pass collector to get top groups per shard
final Weight w = topSearcher.createWeight(topSearcher.rewrite(query), getScores || getMaxScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1); final Weight w = topSearcher.createWeight(topSearcher.rewrite(query), groupSort.needsScores() || docSort.needsScores() || getMaxScores ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES, 1);
final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>(); final List<Collection<SearchGroup<BytesRef>>> shardGroups = new ArrayList<>();
List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>(); List<FirstPassGroupingCollector<?>> firstPassGroupingCollectors = new ArrayList<>();
FirstPassGroupingCollector<?> firstPassCollector = null; FirstPassGroupingCollector<?> firstPassCollector = null;
@ -1208,7 +1201,7 @@ public class TestGrouping extends LuceneTestCase {
final TopGroups<BytesRef>[] shardTopGroups = new TopGroups[subSearchers.length]; final TopGroups<BytesRef>[] shardTopGroups = new TopGroups[subSearchers.length];
for(int shardIDX=0;shardIDX<subSearchers.length;shardIDX++) { for(int shardIDX=0;shardIDX<subSearchers.length;shardIDX++) {
final TopGroupsCollector<?> secondPassCollector = createSecondPassCollector(firstPassGroupingCollectors.get(shardIDX), final TopGroupsCollector<?> secondPassCollector = createSecondPassCollector(firstPassGroupingCollectors.get(shardIDX),
groupField, mergedTopGroups, groupSort, docSort, docOffset + topNDocs, getScores, getMaxScores); groupField, mergedTopGroups, groupSort, docSort, docOffset + topNDocs, getMaxScores);
subSearchers[shardIDX].search(w, secondPassCollector); subSearchers[shardIDX].search(w, secondPassCollector);
shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0); shardTopGroups[shardIDX] = getTopGroups(secondPassCollector, 0);
if (VERBOSE) { if (VERBOSE) {
@ -1232,7 +1225,7 @@ public class TestGrouping extends LuceneTestCase {
} }
} }
private void assertEquals(int[] docIDtoID, TopGroups<BytesRef> expected, TopGroups<BytesRef> actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean testScores, boolean idvBasedImplsUsed) { private void assertEquals(int[] docIDtoID, TopGroups<BytesRef> expected, TopGroups<BytesRef> actual, boolean verifyGroupValues, boolean verifyTotalGroupCount, boolean idvBasedImplsUsed) {
if (expected == null) { if (expected == null) {
assertNull(actual); assertNull(actual);
return; return;
@ -1279,12 +1272,6 @@ public class TestGrouping extends LuceneTestCase {
final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX]; final FieldDoc actualFD = (FieldDoc) actualFDs[docIDX];
//System.out.println(" actual doc=" + docIDtoID[actualFD.doc] + " score=" + actualFD.score); //System.out.println(" actual doc=" + docIDtoID[actualFD.doc] + " score=" + actualFD.score);
assertEquals(expectedFD.doc, docIDtoID[actualFD.doc]); assertEquals(expectedFD.doc, docIDtoID[actualFD.doc]);
if (testScores) {
assertEquals(expectedFD.score, actualFD.score, 0.1);
} else {
// TODO: too anal for now
//assertEquals(Float.NaN, actualFD.score);
}
assertArrayEquals(expectedFD.fields, actualFD.fields); assertArrayEquals(expectedFD.fields, actualFD.fields);
} }
} }

View File

@ -647,7 +647,7 @@ public class AnalyzingInfixSuggester extends Lookup implements Closeable {
//System.out.println("finalQuery=" + finalQuery); //System.out.println("finalQuery=" + finalQuery);
// Sort by weight, descending: // Sort by weight, descending:
TopFieldCollector c = TopFieldCollector.create(SORT, num, false, false); TopFieldCollector c = TopFieldCollector.create(SORT, num, false);
List<LookupResult> results = null; List<LookupResult> results = null;
SearcherManager mgr; SearcherManager mgr;
IndexSearcher searcher; IndexSearcher searcher;

View File

@ -535,7 +535,7 @@ public class ExpandComponent extends SearchComponent implements PluginInfoInitia
DocIdSetIterator iterator = new BitSetIterator(groupBits, 0); // cost is not useful here DocIdSetIterator iterator = new BitSetIterator(groupBits, 0); // cost is not useful here
int group; int group;
while ((group = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { while ((group = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, false, true); Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, true);
groups.put(group, collector); groups.put(group, collector);
} }
@ -619,7 +619,7 @@ public class ExpandComponent extends SearchComponent implements PluginInfoInitia
Iterator<LongCursor> iterator = groupSet.iterator(); Iterator<LongCursor> iterator = groupSet.iterator();
while (iterator.hasNext()) { while (iterator.hasNext()) {
LongCursor cursor = iterator.next(); LongCursor cursor = iterator.next();
Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, false, true); Collector collector = (sort == null) ? TopScoreDocCollector.create(limit) : TopFieldCollector.create(sort, limit, true);
groups.put(cursor.value, collector); groups.put(cursor.value, collector);
} }

View File

@ -1327,6 +1327,7 @@ public class QueryComponent extends SearchComponent
secondPhaseBuilder.addCommandField( secondPhaseBuilder.addCommandField(
new TopGroupsFieldCommand.Builder() new TopGroupsFieldCommand.Builder()
.setQuery(cmd.getQuery())
.setField(schemaField) .setField(schemaField)
.setGroupSort(groupingSpec.getGroupSort()) .setGroupSort(groupingSpec.getGroupSort())
.setSortWithinGroup(groupingSpec.getSortWithinGroup()) .setSortWithinGroup(groupingSpec.getSortWithinGroup())

View File

@ -590,6 +590,14 @@ public class Grouping {
return null; return null;
} }
/**
 * Sets scores on the documents of every group in {@code result} by handing each group's
 * {@code scoreDocs} to {@link TopFieldCollector#populateScores}, using this command's
 * {@code searcher} and {@code query}. Does nothing unless {@code needScores} is set.
 *
 * <p>{@code result} is dereferenced without a null check, so it must already be assigned
 * when scores are needed.
 *
 * @throws IOException propagated from {@link TopFieldCollector#populateScores}
 */
protected void populateScoresIfNecessary() throws IOException {
  if (!needScores) {
    return;
  }
  for (GroupDocs<?> group : result.groups) {
    TopFieldCollector.populateScores(group.scoreDocs, searcher, query);
  }
}
protected NamedList commonResponse() { protected NamedList commonResponse() {
NamedList groupResult = new SimpleOrderedMap(); NamedList groupResult = new SimpleOrderedMap();
grouped.add(key, groupResult); // grouped={ key={ grouped.add(key, groupResult); // grouped={ key={
@ -747,7 +755,7 @@ public class Grouping {
groupedDocsToCollect = Math.max(groupedDocsToCollect, 1); groupedDocsToCollect = Math.max(groupedDocsToCollect, 1);
Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE; Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE;
secondPass = new TopGroupsCollector<>(new TermGroupSelector(groupBy), secondPass = new TopGroupsCollector<>(new TermGroupSelector(groupBy),
topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores, needScores topGroups, groupSort, withinGroupSort, groupedDocsToCollect, needScores
); );
if (totalCount == TotalCount.grouped) { if (totalCount == TotalCount.grouped) {
@ -766,7 +774,10 @@ public class Grouping {
@Override @Override
protected void finish() throws IOException { protected void finish() throws IOException {
result = secondPass != null ? secondPass.getTopGroups(0) : null; if (secondPass != null) {
result = secondPass.getTopGroups(0);
populateScoresIfNecessary();
}
if (main) { if (main) {
mainResult = createSimpleResponse(); mainResult = createSimpleResponse();
return; return;
@ -850,7 +861,7 @@ public class Grouping {
if (withinGroupSort == null || withinGroupSort.equals(Sort.RELEVANCE)) { if (withinGroupSort == null || withinGroupSort.equals(Sort.RELEVANCE)) {
subCollector = topCollector = TopScoreDocCollector.create(groupDocsToCollect); subCollector = topCollector = TopScoreDocCollector.create(groupDocsToCollect);
} else { } else {
topCollector = TopFieldCollector.create(searcher.weightSort(withinGroupSort), groupDocsToCollect, needScores, true); topCollector = TopFieldCollector.create(searcher.weightSort(withinGroupSort), groupDocsToCollect, true);
if (needScores) { if (needScores) {
maxScoreCollector = new MaxScoreCollector(); maxScoreCollector = new MaxScoreCollector();
subCollector = MultiCollector.wrap(topCollector, maxScoreCollector); subCollector = MultiCollector.wrap(topCollector, maxScoreCollector);
@ -952,7 +963,7 @@ public class Grouping {
groupdDocsToCollect = Math.max(groupdDocsToCollect, 1); groupdDocsToCollect = Math.max(groupdDocsToCollect, 1);
Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE; Sort withinGroupSort = this.withinGroupSort != null ? this.withinGroupSort : Sort.RELEVANCE;
secondPass = new TopGroupsCollector<>(newSelector(), secondPass = new TopGroupsCollector<>(newSelector(),
topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores, needScores topGroups, groupSort, withinGroupSort, groupdDocsToCollect, needScores
); );
if (totalCount == TotalCount.grouped) { if (totalCount == TotalCount.grouped) {
@ -971,7 +982,10 @@ public class Grouping {
@Override @Override
protected void finish() throws IOException { protected void finish() throws IOException {
result = secondPass != null ? secondPass.getTopGroups(0) : null; if (secondPass != null) {
result = secondPass.getTopGroups(0);
populateScoresIfNecessary();
}
if (main) { if (main) {
mainResult = createSimpleResponse(); mainResult = createSimpleResponse();
return; return;

View File

@ -27,6 +27,7 @@ import com.carrotsearch.hppc.IntIntHashMap;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Rescorer; import org.apache.lucene.search.Rescorer;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoreMode;
@ -49,6 +50,8 @@ public class ReRankCollector extends TopDocsCollector {
final private int length; final private int length;
final private Set<BytesRef> boostedPriority; // order is the "priority" final private Set<BytesRef> boostedPriority; // order is the "priority"
final private Rescorer reRankQueryRescorer; final private Rescorer reRankQueryRescorer;
final private Sort sort;
final private Query query;
public ReRankCollector(int reRankDocs, public ReRankCollector(int reRankDocs,
@ -61,13 +64,15 @@ public class ReRankCollector extends TopDocsCollector {
this.reRankDocs = reRankDocs; this.reRankDocs = reRankDocs;
this.length = length; this.length = length;
this.boostedPriority = boostedPriority; this.boostedPriority = boostedPriority;
this.query = cmd.getQuery();
Sort sort = cmd.getSort(); Sort sort = cmd.getSort();
if(sort == null) { if(sort == null) {
this.sort = null;
this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length)); this.mainCollector = TopScoreDocCollector.create(Math.max(this.reRankDocs, length));
} else { } else {
sort = sort.rewrite(searcher); this.sort = sort = sort.rewrite(searcher);
//scores are needed for Rescorer (regardless of whether sort needs it) //scores are needed for Rescorer (regardless of whether sort needs it)
this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), true, true); this.mainCollector = TopFieldCollector.create(sort, Math.max(this.reRankDocs, length), true);
} }
this.searcher = searcher; this.searcher = searcher;
this.reRankQueryRescorer = reRankQueryRescorer; this.reRankQueryRescorer = reRankQueryRescorer;
@ -84,7 +89,7 @@ public class ReRankCollector extends TopDocsCollector {
@Override @Override
public ScoreMode scoreMode() { public ScoreMode scoreMode() {
return ScoreMode.COMPLETE; // since the scores will be needed by Rescorer as input regardless of mainCollector return sort == null || sort.needsScores() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES;
} }
public TopDocs topDocs(int start, int howMany) { public TopDocs topDocs(int start, int howMany) {
@ -97,6 +102,10 @@ public class ReRankCollector extends TopDocsCollector {
return mainDocs; return mainDocs;
} }
if (sort != null) {
TopFieldCollector.populateScores(mainDocs.scoreDocs, searcher, query);
}
ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs; ScoreDoc[] mainScoreDocs = mainDocs.scoreDocs;
ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)]; ScoreDoc[] reRankScoreDocs = new ScoreDoc[Math.min(mainScoreDocs.length, reRankDocs)];
System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length); System.arraycopy(mainScoreDocs, 0, reRankScoreDocs, 0, reRankScoreDocs.length);

View File

@ -1532,12 +1532,11 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
return TopScoreDocCollector.create(len); return TopScoreDocCollector.create(len);
} else { } else {
// we have a sort // we have a sort
final boolean needScores = (cmd.getFlags() & GET_SCORES) != 0;
final Sort weightedSort = weightSort(cmd.getSort()); final Sort weightedSort = weightSort(cmd.getSort());
final CursorMark cursor = cmd.getCursorMark(); final CursorMark cursor = cmd.getCursorMark();
final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null); final FieldDoc searchAfter = (null != cursor ? cursor.getSearchAfterFieldDoc() : null);
return TopFieldCollector.create(weightedSort, len, searchAfter, needScores, true); return TopFieldCollector.create(weightedSort, len, searchAfter, true);
} }
} }
@ -1624,6 +1623,9 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
totalHits = topCollector.getTotalHits(); totalHits = topCollector.getTotalHits();
TopDocs topDocs = topCollector.topDocs(0, len); TopDocs topDocs = topCollector.topDocs(0, len);
if (cmd.getSort() != null && query instanceof RankQuery == false && (cmd.getFlags() & GET_SCORES) != 0) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
populateNextCursorMarkFromTopDocs(qr, cmd, topDocs); populateNextCursorMarkFromTopDocs(qr, cmd, topDocs);
maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f; maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f;
@ -1732,6 +1734,9 @@ public class SolrIndexSearcher extends IndexSearcher implements Closeable, SolrI
assert (totalHits == set.size()); assert (totalHits == set.size());
TopDocs topDocs = topCollector.topDocs(0, len); TopDocs topDocs = topCollector.topDocs(0, len);
if (cmd.getSort() != null && query instanceof RankQuery == false && (cmd.getFlags() & GET_SCORES) != 0) {
TopFieldCollector.populateScores(topDocs.scoreDocs, this, query);
}
populateNextCursorMarkFromTopDocs(qr, cmd, topDocs); populateNextCursorMarkFromTopDocs(qr, cmd, topDocs);
maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f; maxScore = totalHits > 0 ? (maxScoreCollector == null ? Float.NaN : maxScoreCollector.getMaxScore()) : 0.0f;
nDocsReturned = topDocs.scoreDocs.length; nDocsReturned = topDocs.scoreDocs.length;

View File

@ -17,6 +17,7 @@
package org.apache.solr.search.grouping; package org.apache.solr.search.grouping;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import java.io.IOException; import java.io.IOException;
import java.util.List; import java.util.List;
@ -39,6 +40,12 @@ public interface Command<T> {
*/ */
List<Collector> create() throws IOException; List<Collector> create() throws IOException;
/**
 * Run post-collection steps. This hook is invoked on every command once the
 * search has been executed, before results are read. The default
 * implementation does nothing.
 *
 * @param searcher the searcher made available to the command for its
 *                 post-collection work
 * @throws IOException If I/O related errors occur
 */
default void postCollect(IndexSearcher searcher) throws IOException {}
/** /**
* Returns the results that the collectors created * Returns the results that the collectors created
* by {@link #create()} contain after a search has been executed. * by {@link #create()} contain after a search has been executed.

View File

@ -163,6 +163,10 @@ public class CommandHandler {
} else { } else {
searchWithTimeLimiter(query, filter, null); searchWithTimeLimiter(query, filter, null);
} }
for (Command command : commands) {
command.postCollect(searcher);
}
} }
private DocSet computeGroupedDocSet(Query query, ProcessedFilter filter, List<Collector> collectors) throws IOException { private DocSet computeGroupedDocSet(Query query, ProcessedFilter filter, List<Collector> collectors) throws IOException {

View File

@ -114,6 +114,7 @@ public class QueryCommand implements Command<QueryCommandResult> {
private TopDocsCollector topDocsCollector; private TopDocsCollector topDocsCollector;
private FilterCollector filterCollector; private FilterCollector filterCollector;
private MaxScoreCollector maxScoreCollector; private MaxScoreCollector maxScoreCollector;
private TopDocs topDocs;
private QueryCommand(Sort sort, Query query, int docsToCollect, boolean needScores, DocSet docSet, String queryString) { private QueryCommand(Sort sort, Query query, int docsToCollect, boolean needScores, DocSet docSet, String queryString) {
this.sort = sort; this.sort = sort;
@ -130,7 +131,7 @@ public class QueryCommand implements Command<QueryCommandResult> {
if (sort == null || sort.equals(Sort.RELEVANCE)) { if (sort == null || sort.equals(Sort.RELEVANCE)) {
subCollector = topDocsCollector = TopScoreDocCollector.create(docsToCollect); subCollector = topDocsCollector = TopScoreDocCollector.create(docsToCollect);
} else { } else {
topDocsCollector = TopFieldCollector.create(sort, docsToCollect, needScores, true); topDocsCollector = TopFieldCollector.create(sort, docsToCollect, true);
if (needScores) { if (needScores) {
maxScoreCollector = new MaxScoreCollector(); maxScoreCollector = new MaxScoreCollector();
subCollector = MultiCollector.wrap(topDocsCollector, maxScoreCollector); subCollector = MultiCollector.wrap(topDocsCollector, maxScoreCollector);
@ -143,8 +144,15 @@ public class QueryCommand implements Command<QueryCommandResult> {
} }
@Override @Override
public QueryCommandResult result() { public void postCollect(IndexSearcher searcher) throws IOException {
TopDocs topDocs = topDocsCollector.topDocs(); topDocs = topDocsCollector.topDocs();
if (needScores) {
TopFieldCollector.populateScores(topDocs.scoreDocs, searcher, query);
}
}
@Override
public QueryCommandResult result() throws IOException {
float maxScore; float maxScore;
if (sort == null) { if (sort == null) {
maxScore = topDocs.scoreDocs.length == 0 ? Float.NaN : topDocs.scoreDocs[0].score; maxScore = topDocs.scoreDocs.length == 0 ? Float.NaN : topDocs.scoreDocs[0].score;

View File

@ -25,7 +25,10 @@ import java.util.List;
import org.apache.lucene.queries.function.ValueSource; import org.apache.lucene.queries.function.ValueSource;
import org.apache.lucene.search.Collector; import org.apache.lucene.search.Collector;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort; import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.grouping.GroupDocs; import org.apache.lucene.search.grouping.GroupDocs;
import org.apache.lucene.search.grouping.SearchGroup; import org.apache.lucene.search.grouping.SearchGroup;
import org.apache.lucene.search.grouping.TermGroupSelector; import org.apache.lucene.search.grouping.TermGroupSelector;
@ -45,6 +48,7 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
public static class Builder { public static class Builder {
private Query query;
private SchemaField field; private SchemaField field;
private Sort groupSort; private Sort groupSort;
private Sort withinGroupSort; private Sort withinGroupSort;
@ -53,6 +57,11 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
private boolean needScores = false; private boolean needScores = false;
private boolean needMaxScore = false; private boolean needMaxScore = false;
/**
 * Sets the query for the command being built. The query is a required
 * field: {@code build()} rejects a builder whose query is {@code null}.
 *
 * @param query the query to hand to the built command
 * @return this builder, to allow call chaining
 */
public Builder setQuery(Query query) {
this.query = query;
return this;
}
public Builder setField(SchemaField field) { public Builder setField(SchemaField field) {
this.field = field; this.field = field;
return this; return this;
@ -89,16 +98,17 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
} }
public TopGroupsFieldCommand build() { public TopGroupsFieldCommand build() {
if (field == null || groupSort == null || withinGroupSort == null || firstPhaseGroups == null || if (query == null || field == null || groupSort == null || withinGroupSort == null || firstPhaseGroups == null ||
maxDocPerGroup == null) { maxDocPerGroup == null) {
throw new IllegalStateException("All required fields must be set"); throw new IllegalStateException("All required fields must be set");
} }
return new TopGroupsFieldCommand(field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore); return new TopGroupsFieldCommand(query, field, groupSort, withinGroupSort, firstPhaseGroups, maxDocPerGroup, needScores, needMaxScore);
} }
} }
private final Query query;
private final SchemaField field; private final SchemaField field;
private final Sort groupSort; private final Sort groupSort;
private final Sort withinGroupSort; private final Sort withinGroupSort;
@ -107,14 +117,17 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
private final boolean needScores; private final boolean needScores;
private final boolean needMaxScore; private final boolean needMaxScore;
private TopGroupsCollector secondPassCollector; private TopGroupsCollector secondPassCollector;
private TopGroups<BytesRef> topGroups;
private TopGroupsFieldCommand(SchemaField field, private TopGroupsFieldCommand(Query query,
SchemaField field,
Sort groupSort, Sort groupSort,
Sort withinGroupSort, Sort withinGroupSort,
Collection<SearchGroup<BytesRef>> firstPhaseGroups, Collection<SearchGroup<BytesRef>> firstPhaseGroups,
int maxDocPerGroup, int maxDocPerGroup,
boolean needScores, boolean needScores,
boolean needMaxScore) { boolean needMaxScore) {
this.query = query;
this.field = field; this.field = field;
this.groupSort = groupSort; this.groupSort = groupSort;
this.withinGroupSort = withinGroupSort; this.withinGroupSort = withinGroupSort;
@ -136,11 +149,11 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
ValueSource vs = fieldType.getValueSource(field, null); ValueSource vs = fieldType.getValueSource(field, null);
Collection<SearchGroup<MutableValue>> v = GroupConverter.toMutable(field, firstPhaseGroups); Collection<SearchGroup<MutableValue>> v = GroupConverter.toMutable(field, firstPhaseGroups);
secondPassCollector = new TopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()), secondPassCollector = new TopGroupsCollector<>(new ValueSourceGroupSelector(vs, new HashMap<>()),
v, groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore v, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore
); );
} else { } else {
secondPassCollector = new TopGroupsCollector<>(new TermGroupSelector(field.getName()), secondPassCollector = new TopGroupsCollector<>(new TermGroupSelector(field.getName()),
firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needScores, needMaxScore firstPhaseGroups, groupSort, withinGroupSort, maxDocPerGroup, needMaxScore
); );
} }
collectors.add(secondPassCollector); collectors.add(secondPassCollector);
@ -148,18 +161,27 @@ public class TopGroupsFieldCommand implements Command<TopGroups<BytesRef>> {
} }
@Override @Override
@SuppressWarnings("unchecked") public void postCollect(IndexSearcher searcher) throws IOException {
public TopGroups<BytesRef> result() {
if (firstPhaseGroups.isEmpty()) { if (firstPhaseGroups.isEmpty()) {
return new TopGroups<>(groupSort.getSort(), withinGroupSort.getSort(), 0, 0, new GroupDocs[0], Float.NaN); topGroups = new TopGroups<>(groupSort.getSort(), withinGroupSort.getSort(), 0, 0, new GroupDocs[0], Float.NaN);
return;
} }
FieldType fieldType = field.getType(); FieldType fieldType = field.getType();
if (fieldType.getNumberType() != null) { if (fieldType.getNumberType() != null) {
return GroupConverter.fromMutable(field, secondPassCollector.getTopGroups(0)); topGroups = GroupConverter.fromMutable(field, secondPassCollector.getTopGroups(0));
} else { } else {
return secondPassCollector.getTopGroups(0); topGroups = secondPassCollector.getTopGroups(0);
} }
for (GroupDocs<?> group : topGroups.groups) {
TopFieldCollector.populateScores(group.scoreDocs, searcher, query);
}
}
@Override
@SuppressWarnings("unchecked")
public TopGroups<BytesRef> result() throws IOException {
return topGroups;
} }
@Override @Override

View File

@ -282,9 +282,8 @@ public class TestSort extends SolrTestCaseJ4 {
final String nullRep = luceneSort || sortMissingFirst && !reverse || sortMissingLast && reverse ? "" : "zzz"; final String nullRep = luceneSort || sortMissingFirst && !reverse || sortMissingLast && reverse ? "" : "zzz";
final String nullRep2 = luceneSort2 || sortMissingFirst2 && !reverse2 || sortMissingLast2 && reverse2 ? "" : "zzz"; final String nullRep2 = luceneSort2 || sortMissingFirst2 && !reverse2 || sortMissingLast2 && reverse2 ? "" : "zzz";
boolean trackScores = r.nextBoolean();
boolean scoreInOrder = r.nextBoolean(); boolean scoreInOrder = r.nextBoolean();
final TopFieldCollector topCollector = TopFieldCollector.create(sort, top, trackScores, true); final TopFieldCollector topCollector = TopFieldCollector.create(sort, top, true);
final List<MyDoc> collectedDocs = new ArrayList<>(); final List<MyDoc> collectedDocs = new ArrayList<>();
// delegate and collect docs ourselves // delegate and collect docs ourselves

View File

@ -168,9 +168,9 @@ public class TestFieldCacheSortRandom extends LuceneTestCase {
int queryType = random.nextInt(2); int queryType = random.nextInt(2);
if (queryType == 0) { if (queryType == 0) {
hits = s.search(new ConstantScoreQuery(f), hits = s.search(new ConstantScoreQuery(f),
hitCount, sort, random.nextBoolean()); hitCount, sort, false);
} else { } else {
hits = s.search(f, hitCount, sort, random.nextBoolean()); hits = s.search(f, hitCount, sort, false);
} }
if (VERBOSE) { if (VERBOSE) {