Avoid double term construction in DfsPhase (#38716)

DfsPhase captures terms used for scoring a query in order to build global term statistics across
multiple shards for more accurate scoring. It currently does this by building the query's `Weight`
and calling `extractTerms` on it to collect terms, and then calling `IndexSearcher.termStatistics()`
for each collected term. This duplicates work, however, as the various `Weight` implementations 
will already have collected these statistics at construction time.

This commit replaces this round-about way of collecting stats, instead using a delegating
IndexSearcher that collects the term contexts and statistics when `IndexSearcher.termStatistics()`
is called from the Weight.

It also fixes a bug when using rescorers, where a `QueryRescorer` would calculate distributed term
statistics, but ignore field statistics.  `Rescorer.extractTerms` has been removed, and replaced with
a new method on `RescoreContext` that returns any queries used by the rescore implementation.
The delegating IndexSearcher then collects term contexts and statistics in the same way described
above for each Query.
This commit is contained in:
Alan Woodward 2019-02-15 15:42:04 +00:00 committed by Alan Woodward
parent 27cf7e27e7
commit 176013e23c
5 changed files with 62 additions and 118 deletions

View File

@ -20,7 +20,6 @@
package org.elasticsearch.example.rescore; package org.elasticsearch.example.rescore;
import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
@ -46,7 +45,6 @@ import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator; import java.util.Iterator;
import java.util.Objects; import java.util.Objects;
import java.util.Set;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg; import static org.elasticsearch.common.xcontent.ConstructingObjectParser.constructorArg;
@ -224,9 +222,5 @@ public class ExampleRescoreBuilder extends RescorerBuilder<ExampleRescoreBuilder
return Explanation.match(context.factor, "test", singletonList(sourceExplanation)); return Explanation.match(context.factor, "test", singletonList(sourceExplanation));
} }
@Override
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) {
// Since we don't use queries there are no terms to extract.
}
} }
} }

View File

@ -19,14 +19,12 @@
package org.elasticsearch.search.dfs; package org.elasticsearch.search.dfs;
import com.carrotsearch.hppc.ObjectHashSet;
import com.carrotsearch.hppc.ObjectObjectHashMap; import com.carrotsearch.hppc.ObjectObjectHashMap;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.Term; import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates; import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.CollectionStatistics; import org.apache.lucene.search.CollectionStatistics;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermStatistics; import org.apache.lucene.search.TermStatistics;
import org.elasticsearch.common.collect.HppcMaps; import org.elasticsearch.common.collect.HppcMaps;
@ -36,9 +34,8 @@ import org.elasticsearch.search.rescore.RescoreContext;
import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskCancelledException;
import java.io.IOException; import java.io.IOException;
import java.util.AbstractSet; import java.util.HashMap;
import java.util.Collection; import java.util.Map;
import java.util.Iterator;
/** /**
* Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase. * Dfs phase of a search request, used to make scoring 100% accurate by collecting additional info from each shard before the query phase.
@ -52,42 +49,46 @@ public class DfsPhase implements SearchPhase {
@Override @Override
public void execute(SearchContext context) { public void execute(SearchContext context) {
final ObjectHashSet<Term> termsSet = new ObjectHashSet<>();
try { try {
context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1f)
.extractTerms(new DelegateSet(termsSet));
for (RescoreContext rescoreContext : context.rescore()) {
try {
rescoreContext.rescorer().extractTerms(context.searcher(), rescoreContext, new DelegateSet(termsSet));
} catch (IOException e) {
throw new IllegalStateException("Failed to extract terms", e);
}
}
Term[] terms = termsSet.toArray(Term.class);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
IndexReaderContext indexReaderContext = context.searcher().getTopReaderContext();
for (int i = 0; i < terms.length; i++) {
if(context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
// LUCENE 4 UPGRADE: cache TermStates?
TermStates termContext = TermStates.build(indexReaderContext, terms[i], true);
termStatistics[i] = context.searcher().termStatistics(terms[i], termContext);
}
ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap(); ObjectObjectHashMap<String, CollectionStatistics> fieldStatistics = HppcMaps.newNoNullKeysMap();
for (Term term : terms) { Map<Term, TermStatistics> stats = new HashMap<>();
assert term.field() != null : "field is null"; IndexSearcher searcher = new IndexSearcher(context.searcher().getIndexReader()) {
if (fieldStatistics.containsKey(term.field()) == false) { @Override
final CollectionStatistics collectionStatistics = context.searcher().collectionStatistics(term.field()); public TermStatistics termStatistics(Term term, TermStates states) throws IOException {
if (collectionStatistics != null) { if (context.isCancelled()) {
fieldStatistics.put(term.field(), collectionStatistics);
}
if(context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled");
} }
TermStatistics ts = super.termStatistics(term, states);
if (ts != null) {
stats.put(term, ts);
}
return ts;
} }
@Override
public CollectionStatistics collectionStatistics(String field) throws IOException {
if (context.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
CollectionStatistics cs = super.collectionStatistics(field);
if (cs != null) {
fieldStatistics.put(field, cs);
}
return cs;
}
};
searcher.createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE, 1);
for (RescoreContext rescoreContext : context.rescore()) {
for (Query query : rescoreContext.getQueries()) {
searcher.createWeight(context.searcher().rewrite(query), ScoreMode.COMPLETE, 1);
}
}
Term[] terms = stats.keySet().toArray(new Term[0]);
TermStatistics[] termStatistics = new TermStatistics[terms.length];
for (int i = 0; i < terms.length; i++) {
termStatistics[i] = stats.get(terms[i]);
} }
context.dfsResult().termsStatistics(terms, termStatistics) context.dfsResult().termsStatistics(terms, termStatistics)
@ -95,58 +96,6 @@ public class DfsPhase implements SearchPhase {
.maxDoc(context.searcher().getIndexReader().maxDoc()); .maxDoc(context.searcher().getIndexReader().maxDoc());
} catch (Exception e) { } catch (Exception e) {
throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e); throw new DfsPhaseExecutionException(context, "Exception during dfs phase", e);
} finally {
termsSet.clear(); // don't hold on to terms
}
}
// We need to bridge to JCF world, b/c of Query#extractTerms
private static class DelegateSet extends AbstractSet<Term> {
private final ObjectHashSet<Term> delegate;
private DelegateSet(ObjectHashSet<Term> delegate) {
this.delegate = delegate;
}
@Override
public boolean add(Term term) {
return delegate.add(term);
}
@Override
public boolean addAll(Collection<? extends Term> terms) {
boolean result = false;
for (Term term : terms) {
result = delegate.add(term);
}
return result;
}
@Override
public Iterator<Term> iterator() {
final Iterator<ObjectCursor<Term>> iterator = delegate.iterator();
return new Iterator<Term>() {
@Override
public boolean hasNext() {
return iterator.hasNext();
}
@Override
public Term next() {
return iterator.next().value;
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
};
}
@Override
public int size() {
return delegate.size();
} }
} }

View File

@ -19,19 +19,19 @@
package org.elasticsearch.search.rescore; package org.elasticsearch.search.rescore;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query; import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import java.io.IOException; import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Comparator;
import java.util.Set;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toSet;
public final class QueryRescorer implements Rescorer { public final class QueryRescorer implements Rescorer {
@ -170,6 +170,11 @@ public final class QueryRescorer implements Rescorer {
this.query = query; this.query = query;
} }
@Override
public List<Query> getQueries() {
return Collections.singletonList(query);
}
public Query query() { public Query query() {
return query; return query;
} }
@ -203,10 +208,4 @@ public final class QueryRescorer implements Rescorer {
} }
} }
@Override
public void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException {
Query query = ((QueryRescoreContext) rescoreContext).query();
searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f).extractTerms(termsSet);
}
} }

View File

@ -19,6 +19,10 @@
package org.elasticsearch.search.rescore; package org.elasticsearch.search.rescore;
import org.apache.lucene.search.Query;
import java.util.Collections;
import java.util.List;
import java.util.Set; import java.util.Set;
/** /**
@ -29,7 +33,7 @@ import java.util.Set;
public class RescoreContext { public class RescoreContext {
private final int windowSize; private final int windowSize;
private final Rescorer rescorer; private final Rescorer rescorer;
private Set<Integer> resroredDocs; //doc Ids for which rescoring was applied private Set<Integer> rescoredDocs; //doc Ids for which rescoring was applied
/** /**
* Build the context. * Build the context.
@ -55,10 +59,17 @@ public class RescoreContext {
} }
public void setRescoredDocs(Set<Integer> docIds) { public void setRescoredDocs(Set<Integer> docIds) {
resroredDocs = docIds; rescoredDocs = docIds;
} }
public boolean isRescored(int docId) { public boolean isRescored(int docId) {
return resroredDocs.contains(docId); return rescoredDocs.contains(docId);
}
/**
* Returns queries associated with the rescorer
*/
public List<Query> getQueries() {
return Collections.emptyList();
} }
} }

View File

@ -19,14 +19,11 @@
package org.elasticsearch.search.rescore; package org.elasticsearch.search.rescore;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopDocs;
import org.elasticsearch.action.search.SearchType;
import java.io.IOException; import java.io.IOException;
import java.util.Set;
/** /**
* A query rescorer interface used to re-rank the Top-K results of a previously * A query rescorer interface used to re-rank the Top-K results of a previously
@ -61,10 +58,4 @@ public interface Rescorer {
Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext, Explanation explain(int topLevelDocId, IndexSearcher searcher, RescoreContext rescoreContext,
Explanation sourceExplanation) throws IOException; Explanation sourceExplanation) throws IOException;
/**
* Extracts all terms needed to execute this {@link Rescorer}. This method
* is executed in a distributed frequency collection roundtrip for
* {@link SearchType#DFS_QUERY_THEN_FETCH}
*/
void extractTerms(IndexSearcher searcher, RescoreContext rescoreContext, Set<Term> termsSet) throws IOException;
} }