From a32f6acadf69693855b31fced2678ac9c014895f Mon Sep 17 00:00:00 2001 From: sabi0 <2sabio@gmail.com> Date: Mon, 8 Jan 2024 22:01:56 +0100 Subject: [PATCH] Remove unnecessary fields loop from extractWeightedSpanTerms() (#12965) --- .../highlight/WeightedSpanTermExtractor.java | 116 +++++++----------- 1 file changed, 43 insertions(+), 73 deletions(-) diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java index 9fd09f0eb4d..6c67dc98d56 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/WeightedSpanTermExtractor.java @@ -100,8 +100,8 @@ import org.apache.lucene.util.IOUtils; public class WeightedSpanTermExtractor { private String fieldName; - private TokenStream tokenStream; // set subsequent to getWeightedSpanTerms* methods - private String defaultField; + private TokenStream tokenStream; // set after getWeightedSpanTerms* methods + private final String defaultField; private boolean expandMultiTermQuery; private boolean cachedTokenStream; private boolean wrapToCaching = true; @@ -244,7 +244,6 @@ public class WeightedSpanTermExtractor { && (!expandMultiTermQuery || !fieldNameComparator(((MultiTermQuery) query).getField()))) { return; } - Query origQuery = query; final IndexReader reader = getLeafContext().reader(); Query rewritten; if (query instanceof MultiTermQuery) { @@ -252,12 +251,11 @@ public class WeightedSpanTermExtractor { MultiTermQuery.SCORING_BOOLEAN_REWRITE.rewrite( new IndexSearcher(reader), (MultiTermQuery) query); } else { - rewritten = origQuery.rewrite(new IndexSearcher(reader)); + rewritten = query.rewrite(new IndexSearcher(reader)); } - if (rewritten != origQuery) { + if (rewritten != query) { // only rewrite once and then flatten again - the rewritten query could have a special - // treatment - // if this method is overwritten in a subclass or above in the next recursion + // treatment if this method is overwritten in a subclass or above in the next recursion extract(rewritten, boost, terms); } else { extractUnknownQuery(query, terms); @@ -293,67 +291,49 @@ public class WeightedSpanTermExtractor { */ protected void extractWeightedSpanTerms( Map terms, SpanQuery spanQuery, float boost) throws IOException { - Set fieldNames; - if (fieldName == null) { - fieldNames = new HashSet<>(); - collectSpanQueryFields(spanQuery, fieldNames); - } else { - fieldNames = new HashSet<>(1); - fieldNames.add(fieldName); - } - // To support the use of the default field name - if (defaultField != null) { - fieldNames.add(defaultField); + Set queryFieldNames = new HashSet<>(); + collectSpanQueryFields(spanQuery, queryFieldNames); + if (fieldName != null + && queryFieldNames.contains(fieldName) == false + && (defaultField == null || queryFieldNames.contains(defaultField) == false)) { + return; } - Map queries = new HashMap<>(); - - Set nonWeightedTerms = new HashSet<>(); final boolean mustRewriteQuery = mustRewriteQuery(spanQuery); final IndexSearcher searcher = new IndexSearcher(getLeafContext()); searcher.setQueryCache(null); - if (mustRewriteQuery) { - final SpanQuery rewrittenQuery = (SpanQuery) searcher.rewrite(spanQuery); - for (final String field : fieldNames) { - queries.put(field, rewrittenQuery); - } - rewrittenQuery.visit(QueryVisitor.termCollector(nonWeightedTerms)); - } else { - spanQuery.visit(QueryVisitor.termCollector(nonWeightedTerms)); + final SpanQuery query = mustRewriteQuery ? (SpanQuery) searcher.rewrite(spanQuery) : spanQuery; + + final Set nonWeightedTerms = new HashSet<>(); + query.visit(QueryVisitor.termCollector(nonWeightedTerms)); + if (nonWeightedTerms.isEmpty()) { + return; } - List spanPositions = new ArrayList<>(); + final List spanPositions = new ArrayList<>(); - for (final String field : fieldNames) { - final SpanQuery q; - if (mustRewriteQuery) { - q = queries.get(field); - } else { - q = spanQuery; - } - LeafReaderContext context = getLeafContext(); - SpanWeight w = - (SpanWeight) searcher.createWeight(searcher.rewrite(q), ScoreMode.COMPLETE_NO_SCORES, 1); - Bits acceptDocs = context.reader().getLiveDocs(); - final Spans spans = w.getSpans(context, SpanWeight.Postings.POSITIONS); - if (spans == null) { - return; - } + LeafReaderContext context = getLeafContext(); + SpanWeight w = + (SpanWeight) + searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1); + final Spans spans = w.getSpans(context, SpanWeight.Postings.POSITIONS); + if (spans == null) { + return; + } - // collect span positions - while (spans.nextDoc() != Spans.NO_MORE_DOCS) { - if (acceptDocs != null && acceptDocs.get(spans.docID()) == false) { - continue; - } - while (spans.nextStartPosition() != Spans.NO_MORE_POSITIONS) { - spanPositions.add(new PositionSpan(spans.startPosition(), spans.endPosition() - 1)); - } + final Bits acceptDocs = context.reader().getLiveDocs(); + // collect span positions + while (spans.nextDoc() != Spans.NO_MORE_DOCS) { + if (acceptDocs != null && acceptDocs.get(spans.docID()) == false) { + continue; + } + while (spans.nextStartPosition() != Spans.NO_MORE_POSITIONS) { + spanPositions.add(new PositionSpan(spans.startPosition(), spans.endPosition() - 1)); } } - if (spanPositions.size() == 0) { - // no spans found + if (spanPositions.isEmpty()) { return; } @@ -401,11 +381,9 @@ public class WeightedSpanTermExtractor { /** Necessary to implement matches for queries against defaultField */ protected boolean fieldNameComparator(String fieldNameToCheck) { - boolean rv = - fieldName == null - || fieldName.equals(fieldNameToCheck) - || (defaultField != null && defaultField.equals(fieldNameToCheck)); - return rv; + return fieldName == null + || fieldName.equals(fieldNameToCheck) + || (defaultField != null && defaultField.equals(fieldNameToCheck)); } protected LeafReaderContext getLeafContext() throws IOException { @@ -555,11 +533,7 @@ public class WeightedSpanTermExtractor { public Map getWeightedSpanTermsWithScores( Query query, float boost, TokenStream tokenStream, String fieldName, IndexReader reader) throws IOException { - if (fieldName != null) { - this.fieldName = fieldName; - } else { - this.fieldName = null; - } + this.fieldName = fieldName; this.tokenStream = tokenStream; Map terms = new PositionCheckingMap<>(); @@ -640,7 +614,6 @@ public class WeightedSpanTermExtractor { * This class makes sure that if both position sensitive and insensitive versions of the same term * are added, the position insensitive one wins. */ - @SuppressWarnings("serial") protected static class PositionCheckingMap extends HashMap { @Override @@ -650,15 +623,12 @@ public class WeightedSpanTermExtractor { } @Override - public WeightedSpanTerm put(K key, WeightedSpanTerm value) { - WeightedSpanTerm prev = super.put(key, value); - if (prev == null) return prev; - WeightedSpanTerm prevTerm = prev; - WeightedSpanTerm newTerm = value; - if (!prevTerm.positionSensitive) { + public WeightedSpanTerm put(K key, WeightedSpanTerm newTerm) { + WeightedSpanTerm prevTerm = super.put(key, newTerm); + if (prevTerm != null && prevTerm.positionSensitive == false) { newTerm.positionSensitive = false; } - return prev; + return prevTerm; } }