Make TermStates#build concurrent (#12183)

This commit is contained in:
Shubham Chaudhary 2023-09-21 20:47:36 +05:30 committed by GitHub
parent 3deead0ed3
commit fb1f4dd412
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 109 additions and 46 deletions

View File

@ -148,7 +148,7 @@ Improvements
Optimizations
---------------------
(No changes)
* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary)
Changes in runtime behavior
---------------------

View File

@ -23,7 +23,6 @@ import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
import org.apache.lucene.index.IndexOptions;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.BooleanQuery;
@ -345,7 +344,7 @@ public final class FeatureField extends Field {
if (pivot != null) {
return super.rewrite(indexSearcher);
}
float newPivot = computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature);
float newPivot = computePivotFeatureValue(indexSearcher, field, feature);
return new SaturationFunction(field, feature, newPivot);
}
@ -618,14 +617,14 @@ public final class FeatureField extends Field {
* store the exponent in the higher bits, it means that the result will be an approximation of the
* geometric mean of all feature values.
*
* @param reader the {@link IndexReader} to search against
* @param searcher the {@link IndexSearcher} to perform the search
* @param featureField the field that stores features
* @param featureName the name of the feature
*/
static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName)
throws IOException {
static float computePivotFeatureValue(
IndexSearcher searcher, String featureField, String featureName) throws IOException {
Term term = new Term(featureField, featureName);
TermStates states = TermStates.build(reader.getContext(), term, true);
TermStates states = TermStates.build(searcher, term, true);
if (states.docFreq() == 0) {
// avoid division by 0
// The return value doesn't matter much here, the term doesn't exist,

View File

@ -18,6 +18,9 @@ package org.apache.lucene.index;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.TaskExecutor;
/**
* Maintains a {@link IndexReader} {@link TermState} view over {@link IndexReader} instances
@ -86,19 +89,47 @@ public final class TermStates {
* @param needsStats if {@code true} then all leaf contexts will be visited up-front to collect
* term statistics. Otherwise, the {@link TermState} objects will be built only when requested
*/
public static TermStates build(IndexReaderContext context, Term term, boolean needsStats)
public static TermStates build(IndexSearcher indexSearcher, Term term, boolean needsStats)
throws IOException {
assert context != null && context.isTopLevel;
IndexReaderContext context = indexSearcher.getTopReaderContext();
assert context != null;
final TermStates perReaderTermState = new TermStates(needsStats ? null : term, context);
if (needsStats) {
for (final LeafReaderContext ctx : context.leaves()) {
// if (DEBUG) System.out.println(" r=" + leaves[i].reader);
TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
if (taskExecutor != null) {
// build the term states concurrently
List<TaskExecutor.Task<TermStateInfo>> tasks =
context.leaves().stream()
.map(
ctx ->
taskExecutor.createTask(
() -> {
TermsEnum termsEnum = loadTermsEnum(ctx, term);
if (termsEnum != null) {
final TermState termState = termsEnum.termState();
// if (DEBUG) System.out.println(" found");
return new TermStateInfo(
termsEnum.termState(),
ctx.ord,
termsEnum.docFreq(),
termsEnum.totalTermFreq());
}
return null;
}))
.toList();
List<TermStateInfo> resultInfos = taskExecutor.invokeAll(tasks);
for (TermStateInfo info : resultInfos) {
if (info != null) {
perReaderTermState.register(
termState, ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq());
}
}
} else {
// build the term states sequentially
for (final LeafReaderContext ctx : context.leaves()) {
TermsEnum termsEnum = loadTermsEnum(ctx, term);
if (termsEnum != null) {
perReaderTermState.register(
termsEnum.termState(), ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
}
}
}
}
@ -211,4 +242,40 @@ public final class TermStates {
return sb.toString();
}
/** Wrapper over TermState, ordinal value, term doc frequency and total term frequency */
private static final class TermStateInfo {
private final TermState state;
private final int ordinal;
private final int docFreq;
private final long totalTermFreq;
/** Initialize TermStateInfo */
public TermStateInfo(TermState state, int ordinal, int docFreq, long totalTermFreq) {
this.state = state;
this.ordinal = ordinal;
this.docFreq = docFreq;
this.totalTermFreq = totalTermFreq;
}
/** Get term state */
public TermState getState() {
return state;
}
/** Get ordinal value */
public int getOrdinal() {
return ordinal;
}
/** Get term doc frequency */
public int getDocFreq() {
return docFreq;
}
/** Get total term frequency */
public long getTotalTermFreq() {
return totalTermFreq;
}
}
}

View File

@ -272,7 +272,7 @@ public final class BlendedTermQuery extends Query {
for (int i = 0; i < contexts.length; ++i) {
if (contexts[i] == null
|| contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) {
contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true);
contexts[i] = TermStates.build(indexSearcher, terms[i], true);
}
}

View File

@ -24,7 +24,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
@ -219,7 +218,6 @@ public class MultiPhraseQuery extends Query {
@Override
protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOException {
final IndexReaderContext context = searcher.getTopReaderContext();
// compute idf
ArrayList<TermStatistics> allTermStats = new ArrayList<>();
@ -227,7 +225,7 @@ public class MultiPhraseQuery extends Query {
for (Term term : terms) {
TermStates ts = termStates.get(term);
if (ts == null) {
ts = TermStates.build(context, term, scoreMode.needsScores());
ts = TermStates.build(searcher, term, scoreMode.needsScores());
termStates.put(term, ts);
}
if (scoreMode.needsScores() && ts.docFreq() > 0) {

View File

@ -24,7 +24,6 @@ import java.util.Objects;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader;
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
@ -451,13 +450,12 @@ public class PhraseQuery extends Query {
throw new IllegalStateException(
"PhraseWeight requires that the first position is 0, call rewrite first");
}
final IndexReaderContext context = searcher.getTopReaderContext();
states = new TermStates[terms.length];
TermStatistics[] termStats = new TermStatistics[terms.length];
int termUpTo = 0;
for (int i = 0; i < terms.length; i++) {
final Term term = terms[i];
states[i] = TermStates.build(context, term, scoreMode.needsScores());
states[i] = TermStates.build(searcher, term, scoreMode.needsScores());
if (scoreMode.needsScores()) {
TermStates ts = states[i];
if (ts.docFreq() > 0) {

View File

@ -207,7 +207,7 @@ public final class SynonymQuery extends Query {
termStates = new TermStates[terms.length];
for (int i = 0; i < termStates.length; i++) {
Term term = new Term(field, terms[i].term);
TermStates ts = TermStates.build(searcher.getTopReaderContext(), term, true);
TermStates ts = TermStates.build(searcher, term, true);
termStates[i] = ts;
if (ts.docFreq() > 0) {
TermStatistics termStats =

View File

@ -272,7 +272,7 @@ public class TermQuery extends Query {
final IndexReaderContext context = searcher.getTopReaderContext();
final TermStates termState;
if (perReaderTermState == null || perReaderTermState.wasBuiltFor(context) == false) {
termState = TermStates.build(context, term, scoreMode.needsScores());
termState = TermStates.build(searcher, term, scoreMode.needsScores());
} else {
// PRTS was pre-build for this IS
termState = this.perReaderTermState;

View File

@ -272,7 +272,8 @@ public class TestFeatureField extends LuceneTestCase {
// Make sure that we create a legal pivot on missing features
DirectoryReader reader = writer.getReader();
float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
IndexSearcher searcher = LuceneTestCase.newSearcher(reader);
float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
assertTrue(Float.isFinite(pivot));
assertTrue(pivot > 0);
reader.close();
@ -298,7 +299,8 @@ public class TestFeatureField extends LuceneTestCase {
reader = writer.getReader();
writer.close();
pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
searcher = LuceneTestCase.newSearcher(reader);
pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
double expected = Math.pow(10 * 100 * 1 * 42, 1 / 4.); // geometric mean
assertEquals(expected, pivot, 0.1);

View File

@ -18,6 +18,7 @@
package org.apache.lucene.index;
import org.apache.lucene.document.Document;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.util.LuceneTestCase;
@ -30,8 +31,8 @@ public class TestTermStates extends LuceneTestCase {
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
w.addDocument(new Document());
IndexReader r = w.getReader();
TermStates states =
TermStates.build(r.getContext(), new Term("foo", "bar"), random().nextBoolean());
IndexSearcher s = new IndexSearcher(r);
TermStates states = TermStates.build(s, new Term("foo", "bar"), random().nextBoolean());
assertEquals("TermStates\n state=null\n", states.toString());
IOUtils.close(r, w, dir);
}

View File

@ -365,7 +365,7 @@ public class TestMinShouldMatch2 extends LuceneTestCase {
if (ord >= 0) {
boolean success = ords.add(ord);
assert success; // no dups
TermStates ts = TermStates.build(reader.getContext(), term, true);
TermStates ts = TermStates.build(searcher, term, true);
SimScorer w =
weight.similarity.scorer(
1f,

View File

@ -55,11 +55,12 @@ public class TestTermQuery extends LuceneTestCase {
final CompositeReaderContext context;
try (MultiReader multiReader = new MultiReader()) {
context = multiReader.getContext();
}
IndexSearcher searcher = new IndexSearcher(context);
QueryUtils.checkEqual(
new TermQuery(new Term("foo", "bar")),
new TermQuery(
new Term("foo", "bar"), TermStates.build(context, new Term("foo", "bar"), true)));
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true)));
}
}
public void testCreateWeightDoesNotSeekIfScoresAreNotNeeded() throws IOException {
@ -100,8 +101,7 @@ public class TestTermQuery extends LuceneTestCase {
assertEquals(1, totalHits);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
assertEquals(1, totalHits);
@ -160,10 +160,10 @@ public class TestTermQuery extends LuceneTestCase {
w.addDocument(new Document());
DirectoryReader reader = w.getReader();
IndexSearcher searcher = new IndexSearcher(reader);
TermQuery queryWithContext =
new TermQuery(
new Term("foo", "bar"),
TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
assertNotNull(queryWithContext.getTermStates());
IOUtils.close(reader, w, dir);
}

View File

@ -82,7 +82,7 @@ public class SpanTermQuery extends SpanQuery {
final TermStates context;
final IndexReaderContext topContext = searcher.getTopReaderContext();
if (termStates == null || termStates.wasBuiltFor(topContext) == false) {
context = TermStates.build(topContext, term, scoreMode.needsScores());
context = TermStates.build(searcher, term, scoreMode.needsScores());
} else {
context = termStates;
}

View File

@ -330,7 +330,7 @@ public final class CombinedFieldQuery extends Query implements Accountable {
termStates = new TermStates[fieldTerms.length];
for (int i = 0; i < termStates.length; i++) {
FieldAndWeight field = fieldAndWeights.get(fieldTerms[i].field());
TermStates ts = TermStates.build(searcher.getTopReaderContext(), fieldTerms[i], true);
TermStates ts = TermStates.build(searcher, fieldTerms[i], true);
termStates[i] = ts;
if (ts.docFreq() > 0) {
TermStatistics termStats =

View File

@ -375,7 +375,7 @@ public class PhraseWildcardQuery extends Query {
TermData termData = termsData.getOrCreateTermData(singleTerm.termPosition);
Term term = singleTerm.term;
termData.terms.add(term);
TermStates termStates = TermStates.build(searcher.getIndexReader().getContext(), term, true);
TermStates termStates = TermStates.build(searcher, term, true);
// Collect TermState per segment.
int numMatches = 0;

View File

@ -23,7 +23,6 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.lucene.index.IndexReaderContext;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.ReaderUtil;
@ -209,14 +208,13 @@ public class TermAutomatonQuery extends Query implements Accountable {
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
throws IOException {
IndexReaderContext context = searcher.getTopReaderContext();
Map<Integer, TermStates> termStates = new HashMap<>();
for (Map.Entry<BytesRef, Integer> ent : termToID.entrySet()) {
if (ent.getKey() != null) {
termStates.put(
ent.getValue(),
TermStates.build(context, new Term(field, ent.getKey()), scoreMode.needsScores()));
TermStates.build(searcher, new Term(field, ent.getKey()), scoreMode.needsScores()));
}
}

View File

@ -207,7 +207,7 @@ public abstract class ShardSearchingTestBase extends LuceneTestCase {
}
try {
for (Term term : terms) {
final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
final TermStates ts = TermStates.build(s, term, true);
if (ts.docFreq() > 0) {
stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
}