Remove unnecessary fields loop from extractWeightedSpanTerms() (#12965)

This commit is contained in:
sabi0 2024-01-08 22:01:56 +01:00 committed by GitHub
parent 376bd24693
commit a32f6acadf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 43 additions and 73 deletions

View File

@ -100,8 +100,8 @@ import org.apache.lucene.util.IOUtils;
public class WeightedSpanTermExtractor { public class WeightedSpanTermExtractor {
private String fieldName; private String fieldName;
private TokenStream tokenStream; // set subsequent to getWeightedSpanTerms* methods private TokenStream tokenStream; // set after getWeightedSpanTerms* methods
private String defaultField; private final String defaultField;
private boolean expandMultiTermQuery; private boolean expandMultiTermQuery;
private boolean cachedTokenStream; private boolean cachedTokenStream;
private boolean wrapToCaching = true; private boolean wrapToCaching = true;
@ -244,7 +244,6 @@ public class WeightedSpanTermExtractor {
&& (!expandMultiTermQuery || !fieldNameComparator(((MultiTermQuery) query).getField()))) { && (!expandMultiTermQuery || !fieldNameComparator(((MultiTermQuery) query).getField()))) {
return; return;
} }
Query origQuery = query;
final IndexReader reader = getLeafContext().reader(); final IndexReader reader = getLeafContext().reader();
Query rewritten; Query rewritten;
if (query instanceof MultiTermQuery) { if (query instanceof MultiTermQuery) {
@ -252,12 +251,11 @@ public class WeightedSpanTermExtractor {
MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite( MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite(
new IndexSearcher(reader), (MultiTermQuery) query); new IndexSearcher(reader), (MultiTermQuery) query);
} else { } else {
rewritten = origQuery.rewrite(new IndexSearcher(reader)); rewritten = query.rewrite(new IndexSearcher(reader));
} }
if (rewritten != origQuery) { if (rewritten != query) {
// only rewrite once and then flatten again - the rewritten query could have a special // only rewrite once and then flatten again - the rewritten query could have a special
// treatment // treatment if this method is overwritten in a subclass or above in the next recursion
// if this method is overwritten in a subclass or above in the next recursion
extract(rewritten, boost, terms); extract(rewritten, boost, terms);
} else { } else {
extractUnknownQuery(query, terms); extractUnknownQuery(query, terms);
@ -293,54 +291,38 @@ public class WeightedSpanTermExtractor {
*/ */
protected void extractWeightedSpanTerms( protected void extractWeightedSpanTerms(
Map<String, WeightedSpanTerm> terms, SpanQuery spanQuery, float boost) throws IOException { Map<String, WeightedSpanTerm> terms, SpanQuery spanQuery, float boost) throws IOException {
Set<String> fieldNames;
if (fieldName == null) { Set<String> queryFieldNames = new HashSet<>();
fieldNames = new HashSet<>(); collectSpanQueryFields(spanQuery, queryFieldNames);
collectSpanQueryFields(spanQuery, fieldNames); if (fieldName != null
} else { && queryFieldNames.contains(fieldName) == false
fieldNames = new HashSet<>(1); && (defaultField == null || queryFieldNames.contains(defaultField) == false)) {
fieldNames.add(fieldName); return;
}
// To support the use of the default field name
if (defaultField != null) {
fieldNames.add(defaultField);
} }
Map<String, SpanQuery> queries = new HashMap<>();
Set<Term> nonWeightedTerms = new HashSet<>();
final boolean mustRewriteQuery = mustRewriteQuery(spanQuery); final boolean mustRewriteQuery = mustRewriteQuery(spanQuery);
final IndexSearcher searcher = new IndexSearcher(getLeafContext()); final IndexSearcher searcher = new IndexSearcher(getLeafContext());
searcher.setQueryCache(null); searcher.setQueryCache(null);
if (mustRewriteQuery) { final SpanQuery query = mustRewriteQuery ? (SpanQuery) searcher.rewrite(spanQuery) : spanQuery;
final SpanQuery rewrittenQuery = (SpanQuery) searcher.rewrite(spanQuery);
for (final String field : fieldNames) { final Set<Term> nonWeightedTerms = new HashSet<>();
queries.put(field, rewrittenQuery); query.visit(QueryVisitor.termCollector(nonWeightedTerms));
} if (nonWeightedTerms.isEmpty()) {
rewrittenQuery.visit(QueryVisitor.termCollector(nonWeightedTerms)); return;
} else {
spanQuery.visit(QueryVisitor.termCollector(nonWeightedTerms));
} }
List<PositionSpan> spanPositions = new ArrayList<>(); final List<PositionSpan> spanPositions = new ArrayList<>();
for (final String field : fieldNames) {
final SpanQuery q;
if (mustRewriteQuery) {
q = queries.get(field);
} else {
q = spanQuery;
}
LeafReaderContext context = getLeafContext(); LeafReaderContext context = getLeafContext();
SpanWeight w = SpanWeight w =
(SpanWeight) searcher.createWeight(searcher.rewrite(q), ScoreMode.COMPLETE_NO_SCORES, 1); (SpanWeight)
Bits acceptDocs = context.reader().getLiveDocs(); searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1);
final Spans spans = w.getSpans(context, SpanWeight.Postings.POSITIONS); final Spans spans = w.getSpans(context, SpanWeight.Postings.POSITIONS);
if (spans == null) { if (spans == null) {
return; return;
} }
final Bits acceptDocs = context.reader().getLiveDocs();
// collect span positions // collect span positions
while (spans.nextDoc() != Spans.NO_MORE_DOCS) { while (spans.nextDoc() != Spans.NO_MORE_DOCS) {
if (acceptDocs != null && acceptDocs.get(spans.docID()) == false) { if (acceptDocs != null && acceptDocs.get(spans.docID()) == false) {
@ -350,10 +332,8 @@ public class WeightedSpanTermExtractor {
spanPositions.add(new PositionSpan(spans.startPosition(), spans.endPosition() - 1)); spanPositions.add(new PositionSpan(spans.startPosition(), spans.endPosition() - 1));
} }
} }
}
if (spanPositions.size() == 0) { if (spanPositions.isEmpty()) {
// no spans found
return; return;
} }
@ -401,11 +381,9 @@ public class WeightedSpanTermExtractor {
/** Necessary to implement matches for queries against <code>defaultField</code> */ /** Necessary to implement matches for queries against <code>defaultField</code> */
protected boolean fieldNameComparator(String fieldNameToCheck) { protected boolean fieldNameComparator(String fieldNameToCheck) {
boolean rv = return fieldName == null
fieldName == null
|| fieldName.equals(fieldNameToCheck) || fieldName.equals(fieldNameToCheck)
|| (defaultField != null && defaultField.equals(fieldNameToCheck)); || (defaultField != null && defaultField.equals(fieldNameToCheck));
return rv;
} }
protected LeafReaderContext getLeafContext() throws IOException { protected LeafReaderContext getLeafContext() throws IOException {
@ -555,11 +533,7 @@ public class WeightedSpanTermExtractor {
public Map<String, WeightedSpanTerm> getWeightedSpanTermsWithScores( public Map<String, WeightedSpanTerm> getWeightedSpanTermsWithScores(
Query query, float boost, TokenStream tokenStream, String fieldName, IndexReader reader) Query query, float boost, TokenStream tokenStream, String fieldName, IndexReader reader)
throws IOException { throws IOException {
if (fieldName != null) {
this.fieldName = fieldName; this.fieldName = fieldName;
} else {
this.fieldName = null;
}
this.tokenStream = tokenStream; this.tokenStream = tokenStream;
Map<String, WeightedSpanTerm> terms = new PositionCheckingMap<>(); Map<String, WeightedSpanTerm> terms = new PositionCheckingMap<>();
@ -640,7 +614,6 @@ public class WeightedSpanTermExtractor {
* This class makes sure that if both position sensitive and insensitive versions of the same term * This class makes sure that if both position sensitive and insensitive versions of the same term
* are added, the position insensitive one wins. * are added, the position insensitive one wins.
*/ */
@SuppressWarnings("serial")
protected static class PositionCheckingMap<K> extends HashMap<K, WeightedSpanTerm> { protected static class PositionCheckingMap<K> extends HashMap<K, WeightedSpanTerm> {
@Override @Override
@ -650,15 +623,12 @@ public class WeightedSpanTermExtractor {
} }
@Override @Override
public WeightedSpanTerm put(K key, WeightedSpanTerm value) { public WeightedSpanTerm put(K key, WeightedSpanTerm newTerm) {
WeightedSpanTerm prev = super.put(key, value); WeightedSpanTerm prevTerm = super.put(key, newTerm);
if (prev == null) return prev; if (prevTerm != null && prevTerm.positionSensitive == false) {
WeightedSpanTerm prevTerm = prev;
WeightedSpanTerm newTerm = value;
if (!prevTerm.positionSensitive) {
newTerm.positionSensitive = false; newTerm.positionSensitive = false;
} }
return prev; return prevTerm;
} }
} }