Avoid double term construction in DfsPhase (#38716)
DfsPhase captures the terms used for scoring a query in order to build global term statistics across multiple shards for more accurate scoring. It currently does this by building the query's `Weight` and calling `extractTerms` on it to collect terms, and then calling `IndexSearcher.termStatistics()` for each collected term. This duplicates work, however, as the various `Weight` implementations will already have collected these statistics at construction time. This commit replaces this roundabout way of collecting stats, instead using a delegating IndexSearcher that collects the term contexts and statistics when `IndexSearcher.termStatistics()` is called from the Weight. It also fixes a bug when using rescorers, where a `QueryRescorer` would calculate distributed term statistics but ignore field statistics. `Rescorer.extractTerms` has been removed and replaced with a new method on `RescoreContext` that returns any queries used by the rescore implementation. The delegating IndexSearcher then collects term contexts and statistics in the same way described above for each such query.
This commit is contained in:
parent
27cf7e27e7
commit
176013e23c
|
@ -20,7 +20,6 @@
|
|||
package org.elasticsearch.example.rescore;
|
||||
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
|
@ -46,7 +45,6 @@ import java.io.IOException;
|
|||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.Objects;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.util.Collections.singletonList;
|
||||
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
|
||||
|
@ -224,9 +222,5 @@ public class ExampleRescoreBuilder extends RescorerBuilder<ExampleRescoreBuilder
|
|||
return Explanation.match(context.factor, "test", singletonList(sourceExplanation));
|
||||
}
|
||||
|
||||
@Override
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) {
    // This example rescorer does not execute any Lucene queries, so there are
    // no terms to contribute to the distributed (DFS) term statistics.
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -19,14 +19,12 @@
|
|||
|
||||
package org.elasticsearch.search.dfs;
|
||||
|
||||
import com.carrotsearch.hppc.ObjectHashSet;
|
||||
import com.carrotsearch.hppc.ObjectObjectHashMap;
|
||||
import com.carrotsearch.hppc.cursors.ObjectCursor;
|
||||
|
||||
import org.apache.lucene.index.IndexReaderContext;
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.index.TermStates;
|
||||
import org.apache.lucene.search.CollectionStatistics;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.TermStatistics;
|
||||
import org.elasticsearch.common.collect.HppcMaps;
|
||||
|
@ -36,9 +34,8 @@ import org.elasticsearch.search.rescore.RescoreContext;
|
|||
import org.elasticsearch.tasks.TaskCancelledException;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.AbstractSet;
|
||||
import java.util.Collection;
|
||||
import java.util.Iterator;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
|
||||
|
@ -52,42 +49,46 @@ public class DfsPhase implements SearchPhase {
|
|||
|
||||
@Override
|
||||
public void execute(SearchContext context) {
|
||||
final ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
|
||||
try {
|
||||
context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
|
||||
.extractTerms(new DelegateSet(termsSet));
|
||||
for (RescoreContext rescoreContext : context.rescore()) {
|
||||
try {
|
||||
rescoreContext.rescorer().extractTerms(context.searcher(), rescoreContext, new DelegateSet(termsSet));
|
||||
} catch (IOException e) {
|
||||
throw new IllegalStateException("Failed to extract terms", e);
|
||||
}
|
||||
}
|
||||
|
||||
Term[] terms = termsSet.toArray(Term.class);
|
||||
TermStatistics[] termStatistics = new TermStatistics[terms.length];
|
||||
IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
|
||||
for (int i = 0; i < terms.length; i++) {
|
||||
if(context.isCancelled()) {
|
||||
throw new TaskCancelledException("cancelled");
|
||||
}
|
||||
// LUCENE 4 UPGRADE: cache TermStates?
|
||||
TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
|
||||
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
|
||||
}
|
||||
|
||||
ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
|
||||
for (Term term : terms) {
|
||||
assert term.field() != null : "field is null";
|
||||
if (fieldStatistics.containsKey(term.field()) == false) {
|
||||
final CollectionStatistics collectionStatistics = context.searcher().collectionStatistics(term.field());
|
||||
if (collectionStatistics != null) {
|
||||
fieldStatistics.put(term.field(), collectionStatistics);
|
||||
}
|
||||
if(context.isCancelled()) {
|
||||
Map<Term, TermStatistics> stats = new HashMap<>();
|
||||
IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
|
||||
@Override
|
||||
public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
|
||||
if (context.isCancelled()) {
|
||||
throw new TaskCancelledException("cancelled");
|
||||
}
|
||||
TermStatistics ts = super.termStatistics(term, states);
|
||||
if (ts != null) {
|
||||
stats.put(term, ts);
|
||||
}
|
||||
return ts;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CollectionStatistics collectionStatistics(String field) throws IOException {
|
||||
if (context.isCancelled()) {
|
||||
throw new TaskCancelledException("cancelled");
|
||||
}
|
||||
CollectionStatistics cs = super.collectionStatistics(field);
|
||||
if (cs != null) {
|
||||
fieldStatistics.put(field, cs);
|
||||
}
|
||||
return cs;
|
||||
}
|
||||
};
|
||||
|
||||
searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
|
||||
for (RescoreContext rescoreContext : context.rescore()) {
|
||||
for (Query query : rescoreContext.getQueries()) {
|
||||
searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
|
||||
}
|
||||
}
|
||||
|
||||
Term[] terms = stats.keySet().toArray(new Term[0]);
|
||||
TermStatistics[] termStatistics = new TermStatistics[terms.length];
|
||||
for (int i = 0; i < terms.length; i++) {
|
||||
termStatistics[i] = stats.get(terms[i]);
|
||||
}
|
||||
|
||||
context.dfsResult().termsStatistics(terms, termStatistics)
|
||||
|
@ -95,58 +96,6 @@ public class DfsPhase implements SearchPhase {
|
|||
.maxDoc(context.searcher().getIndexReader().maxDoc());
|
||||
} catch (Exception e) {
|
||||
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
|
||||
} finally {
|
||||
termsSet.clear(); // don't hold on to terms
|
||||
}
|
||||
}
|
||||
|
||||
// We need to bridge to JCF world, b/c of Query#extractTerms
|
||||
private static class DelegateSet extends AbstractSet<Term> {
|
||||
|
||||
private final ObjectHashSet<Term> delegate;
|
||||
|
||||
private DelegateSet(ObjectHashSet<Term> delegate) {
|
||||
this.delegate = delegate;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean add(Term term) {
|
||||
return delegate.add(term);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean addAll(Collection<? extends Term> terms) {
|
||||
boolean result = false;
|
||||
for (Term term : terms) {
|
||||
result = delegate.add(term);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Iterator<Term> iterator() {
|
||||
final Iterator<ObjectCursor<Term>> iterator = delegate.iterator();
|
||||
return new Iterator<Term>() {
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return iterator.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public Term next() {
|
||||
return iterator.next().value;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public int size() {
|
||||
return delegate.size();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,19 +19,19 @@
|
|||
|
||||
package org.elasticsearch.search.rescore;
|
||||
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.Query;
|
||||
import org.apache.lucene.search.ScoreDoc;
|
||||
import org.apache.lucene.search.ScoreMode;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Arrays;
|
||||
import java.util.Comparator;
|
||||
import java.util.Set;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static java.util.stream.Collectors.toSet;
|
||||
|
||||
public final class QueryRescorer implements Rescorer {
|
||||
|
@ -170,6 +170,11 @@ public final class QueryRescorer implements Rescorer {
|
|||
this.query = query;
|
||||
}
|
||||
|
||||
@Override
public List<Query> getQueries() {
    // Expose the single rescore query so callers (e.g. the DFS phase, per
    // this commit's description) can collect statistics for it.
    return Collections.singletonList(query);
}
|
||||
|
||||
public Query query() {
|
||||
return query;
|
||||
}
|
||||
|
@ -203,10 +208,4 @@ public final class QueryRescorer implements Rescorer {
|
|||
}
|
||||
}
|
||||
|
||||
@Override
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException {
    // Build a Weight for the rescore query and let it report the terms it
    // uses. COMPLETE_NO_SCORES suffices here: only the terms are needed,
    // not scoring structures.
    Query query = ((QueryRescoreContext) rescoreContext).query();
    searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f).extractTerms(termsSet);
}
|
||||
|
||||
}
|
||||
|
|
|
@ -19,6 +19,10 @@
|
|||
|
||||
package org.elasticsearch.search.rescore;
|
||||
|
||||
import org.apache.lucene.search.Query;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
|
@ -29,7 +33,7 @@ import java.util.Set;
|
|||
public class RescoreContext {
|
||||
private final int windowSize;
|
||||
private final Rescorer rescorer;
|
||||
private Set<Integer> resroredDocs; //doc Ids for which rescoring was applied
|
||||
private Set<Integer> rescoredDocs; //doc Ids for which rescoring was applied
|
||||
|
||||
/**
|
||||
* Build the context.
|
||||
|
@ -55,10 +59,17 @@ public class RescoreContext {
|
|||
}
|
||||
|
||||
public void setRescoredDocs(Set<Integer> docIds) {
|
||||
resroredDocs = docIds;
|
||||
rescoredDocs = docIds;
|
||||
}
|
||||
|
||||
public boolean isRescored(int docId) {
|
||||
return resroredDocs.contains(docId);
|
||||
return rescoredDocs.contains(docId);
|
||||
}
|
||||
|
||||
/**
 * Returns the queries used by this rescorer, so that term and field
 * statistics can be collected for them. The default implementation
 * returns an empty list, meaning the rescorer contributes no queries.
 */
public List<Query> getQueries() {
    return Collections.emptyList();
}
|
||||
}
|
||||
|
|
|
@ -19,14 +19,11 @@
|
|||
|
||||
package org.elasticsearch.search.rescore;
|
||||
|
||||
import org.apache.lucene.index.Term;
|
||||
import org.apache.lucene.search.Explanation;
|
||||
import org.apache.lucene.search.IndexSearcher;
|
||||
import org.apache.lucene.search.TopDocs;
|
||||
import org.elasticsearch.action.search.SearchType;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Set;
|
||||
|
||||
/**
|
||||
* A query rescorer interface used to re-rank the Top-K results of a previously
|
||||
|
@ -61,10 +58,4 @@ public interface Rescorer {
|
|||
Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext,
|
||||
Explanation sourceExplanation) throws IOException;
|
||||
|
||||
/**
|
||||
* Extracts all terms needed to execute this {@link Rescorer}. This method
|
||||
* is executed in a distributed frequency collection roundtrip for
|
||||
* {@link SearchType#DFS_QUERY_THEN_FETCH}
|
||||
*/
|
||||
void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue