mirror of https://github.com/apache/lucene.git
Make TermStates#build concurrent (#12183)
parent 3deead0ed3
commit fb1f4dd412

@@ -148,7 +148,7 @@ Improvements
 Optimizations
 ---------------------
 
-(No changes)
+* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary)
 
 Changes in runtime behavior
 ---------------------

@@ -23,7 +23,6 @@ import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
 import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute;
 import org.apache.lucene.index.IndexOptions;
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.index.TermStates;
 import org.apache.lucene.search.BooleanQuery;

@@ -345,7 +344,7 @@ public final class FeatureField extends Field {
       if (pivot != null) {
         return super.rewrite(indexSearcher);
       }
-      float newPivot = computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature);
+      float newPivot = computePivotFeatureValue(indexSearcher, field, feature);
       return new SaturationFunction(field, feature, newPivot);
     }
 
@@ -618,14 +617,14 @@ public final class FeatureField extends Field {
    * store the exponent in the higher bits, it means that the result will be an approximation of the
    * geometric mean of all feature values.
    *
-   * @param reader the {@link IndexReader} to search against
+   * @param searcher the {@link IndexSearcher} to perform the search
    * @param featureField the field that stores features
    * @param featureName the name of the feature
    */
-  static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName)
-      throws IOException {
+  static float computePivotFeatureValue(
+      IndexSearcher searcher, String featureField, String featureName) throws IOException {
     Term term = new Term(featureField, featureName);
-    TermStates states = TermStates.build(reader.getContext(), term, true);
+    TermStates states = TermStates.build(searcher, term, true);
     if (states.docFreq() == 0) {
       // avoid division by 0
       // The return value doesn't matter much here, the term doesn't exist,

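For orientation: computePivotFeatureValue is package-private, and the usual way it runs is through the public no-pivot saturation query, which computes the pivot lazily at rewrite time (see the FeatureField#rewrite hunk above). A minimal sketch of that entry point, not part of this commit; the "features"/"pagerank" names are illustrative:

import java.io.IOException;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TopDocs;

static TopDocs searchWithSaturation(IndexSearcher searcher) throws IOException {
  // No explicit pivot is supplied, so rewrite() derives one from the index
  // via computePivotFeatureValue(searcher, "features", "pagerank").
  Query query = FeatureField.newSaturationQuery("features", "pagerank");
  return searcher.search(query, 10);
}
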
@@ -18,6 +18,9 @@ package org.apache.lucene.index;
 
 import java.io.IOException;
 import java.util.Arrays;
+import java.util.List;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.TaskExecutor;
 
 /**
  * Maintains a {@link IndexReader} {@link TermState} view over {@link IndexReader} instances

@@ -86,19 +89,47 @@ public final class TermStates {
    * @param needsStats if {@code true} then all leaf contexts will be visited up-front to collect
    *     term statistics. Otherwise, the {@link TermState} objects will be built only when requested
    */
-  public static TermStates build(IndexReaderContext context, Term term, boolean needsStats)
+  public static TermStates build(IndexSearcher indexSearcher, Term term, boolean needsStats)
       throws IOException {
-    assert context != null && context.isTopLevel;
+    IndexReaderContext context = indexSearcher.getTopReaderContext();
+    assert context != null;
     final TermStates perReaderTermState = new TermStates(needsStats ? null : term, context);
     if (needsStats) {
-      for (final LeafReaderContext ctx : context.leaves()) {
-        // if (DEBUG) System.out.println(" r=" + leaves[i].reader);
-        TermsEnum termsEnum = loadTermsEnum(ctx, term);
-        if (termsEnum != null) {
-          final TermState termState = termsEnum.termState();
-          // if (DEBUG) System.out.println("    found");
-          perReaderTermState.register(
-              termState, ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
-        }
-      }
+      TaskExecutor taskExecutor = indexSearcher.getTaskExecutor();
+      if (taskExecutor != null) {
+        // build the term states concurrently
+        List<TaskExecutor.Task<TermStateInfo>> tasks =
+            context.leaves().stream()
+                .map(
+                    ctx ->
+                        taskExecutor.createTask(
+                            () -> {
+                              TermsEnum termsEnum = loadTermsEnum(ctx, term);
+                              if (termsEnum != null) {
+                                return new TermStateInfo(
+                                    termsEnum.termState(),
+                                    ctx.ord,
+                                    termsEnum.docFreq(),
+                                    termsEnum.totalTermFreq());
+                              }
+                              return null;
+                            }))
+                .toList();
+        List<TermStateInfo> resultInfos = taskExecutor.invokeAll(tasks);
+        for (TermStateInfo info : resultInfos) {
+          if (info != null) {
+            perReaderTermState.register(
+                info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq());
+          }
+        }
+      } else {
+        // build the term states sequentially
+        for (final LeafReaderContext ctx : context.leaves()) {
+          TermsEnum termsEnum = loadTermsEnum(ctx, term);
+          if (termsEnum != null) {
+            perReaderTermState.register(
+                termsEnum.termState(), ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq());
+          }
+        }
+      }
     }

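Callers opt in to the concurrent path simply by how they construct the searcher: IndexSearcher#getTaskExecutor() is non-null when the searcher was created with an executor, otherwise build() falls back to the sequential loop. A minimal sketch, not part of this commit; the pool size, field, and term are illustrative:

import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;

static TermStates buildConcurrently(IndexReader reader) throws IOException {
  ExecutorService executor = Executors.newFixedThreadPool(4);
  try {
    // A searcher backed by an executor exposes a TaskExecutor, so build()
    // resolves the term's state across all segments in parallel.
    IndexSearcher searcher = new IndexSearcher(reader, executor);
    return TermStates.build(searcher, new Term("body", "lucene"), true);
  } finally {
    executor.shutdown();
  }
}
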
@@ -211,4 +242,40 @@ public final class TermStates {
 
     return sb.toString();
   }
+
+  /** Wrapper over TermState, ordinal value, term doc frequency and total term frequency */
+  private static final class TermStateInfo {
+    private final TermState state;
+    private final int ordinal;
+    private final int docFreq;
+    private final long totalTermFreq;
+
+    /** Initialize TermStateInfo */
+    public TermStateInfo(TermState state, int ordinal, int docFreq, long totalTermFreq) {
+      this.state = state;
+      this.ordinal = ordinal;
+      this.docFreq = docFreq;
+      this.totalTermFreq = totalTermFreq;
+    }
+
+    /** Get term state */
+    public TermState getState() {
+      return state;
+    }
+
+    /** Get ordinal value */
+    public int getOrdinal() {
+      return ordinal;
+    }
+
+    /** Get term doc frequency */
+    public int getDocFreq() {
+      return docFreq;
+    }
+
+    /** Get total term frequency */
+    public long getTotalTermFreq() {
+      return totalTermFreq;
+    }
+  }
 }

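Note the division of labor here: the tasks never mutate the shared TermStates. Each task only reads its own leaf and returns an immutable TermStateInfo (or null when the term is absent from that segment); all register(...) calls happen on the calling thread after invokeAll returns, so perReaderTermState needs no extra synchronization.
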
@@ -272,7 +272,7 @@ public final class BlendedTermQuery extends Query {
     for (int i = 0; i < contexts.length; ++i) {
       if (contexts[i] == null
           || contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) {
-        contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true);
+        contexts[i] = TermStates.build(indexSearcher, terms[i], true);
       }
     }
 
@@ -24,7 +24,6 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
-import org.apache.lucene.index.IndexReaderContext;
 import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.PostingsEnum;

@@ -219,7 +218,6 @@ public class MultiPhraseQuery extends Query {
 
     @Override
     protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOException {
-      final IndexReaderContext context = searcher.getTopReaderContext();
 
       // compute idf
       ArrayList<TermStatistics> allTermStats = new ArrayList<>();

@@ -227,7 +225,7 @@ public class MultiPhraseQuery extends Query {
       for (Term term : terms) {
         TermStates ts = termStates.get(term);
         if (ts == null) {
-          ts = TermStates.build(context, term, scoreMode.needsScores());
+          ts = TermStates.build(searcher, term, scoreMode.needsScores());
           termStates.put(term, ts);
         }
         if (scoreMode.needsScores() && ts.docFreq() > 0) {

|
@ -24,7 +24,6 @@ import java.util.Objects;
|
|||
import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat;
|
||||
import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader;
|
||||
import org.apache.lucene.index.ImpactsEnum;
|
||||
import org.apache.lucene.index.IndexReaderContext;
|
||||
import org.apache.lucene.index.LeafReader;
|
||||
import org.apache.lucene.index.LeafReaderContext;
|
||||
import org.apache.lucene.index.PostingsEnum;
|
||||
|
@ -451,13 +450,12 @@ public class PhraseQuery extends Query {
|
|||
throw new IllegalStateException(
|
||||
"PhraseWeight requires that the first position is 0, call rewrite first");
|
||||
}
|
||||
final IndexReaderContext context = searcher.getTopReaderContext();
|
||||
states = new TermStates[terms.length];
|
||||
TermStatistics[] termStats = new TermStatistics[terms.length];
|
||||
int termUpTo = 0;
|
||||
for (int i = 0; i < terms.length; i++) {
|
||||
final Term term = terms[i];
|
||||
states[i] = TermStates.build(context, term, scoreMode.needsScores());
|
||||
states[i] = TermStates.build(searcher, term, scoreMode.needsScores());
|
||||
if (scoreMode.needsScores()) {
|
||||
TermStates ts = states[i];
|
||||
if (ts.docFreq() > 0) {
|
||||
|
|
|
@@ -207,7 +207,7 @@ public final class SynonymQuery extends Query {
       termStates = new TermStates[terms.length];
       for (int i = 0; i < termStates.length; i++) {
         Term term = new Term(field, terms[i].term);
-        TermStates ts = TermStates.build(searcher.getTopReaderContext(), term, true);
+        TermStates ts = TermStates.build(searcher, term, true);
         termStates[i] = ts;
         if (ts.docFreq() > 0) {
           TermStatistics termStats =

@@ -272,7 +272,7 @@ public class TermQuery extends Query {
     final IndexReaderContext context = searcher.getTopReaderContext();
     final TermStates termState;
     if (perReaderTermState == null || perReaderTermState.wasBuiltFor(context) == false) {
-      termState = TermStates.build(context, term, scoreMode.needsScores());
+      termState = TermStates.build(searcher, term, scoreMode.needsScores());
     } else {
       // PRTS was pre-build for this IS
       termState = this.perReaderTermState;

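The same migration applies to callers that pre-build term states for TermQuery's two-argument constructor: the states should now be built from the searcher that will execute the query, which is also what wasBuiltFor() checks against. A minimal sketch, not part of this commit; the field and term are illustrative:

import java.io.IOException;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermStates;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;

static Query prebuiltTermQuery(IndexSearcher searcher) throws IOException {
  Term term = new Term("foo", "bar");
  // Building against the executing searcher keeps wasBuiltFor() true in
  // createWeight(), so the pre-built states are reused rather than rebuilt.
  TermStates states = TermStates.build(searcher, term, true);
  return new TermQuery(term, states);
}
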
@@ -272,7 +272,8 @@ public class TestFeatureField extends LuceneTestCase {
 
     // Make sure that we create a legal pivot on missing features
     DirectoryReader reader = writer.getReader();
-    float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
+    IndexSearcher searcher = LuceneTestCase.newSearcher(reader);
+    float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
     assertTrue(Float.isFinite(pivot));
     assertTrue(pivot > 0);
     reader.close();

@@ -298,7 +299,8 @@ public class TestFeatureField extends LuceneTestCase {
     reader = writer.getReader();
     writer.close();
 
-    pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank");
+    searcher = LuceneTestCase.newSearcher(reader);
+    pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank");
     double expected = Math.pow(10 * 100 * 1 * 42, 1 / 4.); // geometric mean
     assertEquals(expected, pivot, 0.1);
 

@@ -18,6 +18,7 @@
 package org.apache.lucene.index;
 
 import org.apache.lucene.document.Document;
+import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.store.Directory;
 import org.apache.lucene.tests.index.RandomIndexWriter;
 import org.apache.lucene.tests.util.LuceneTestCase;

@@ -30,8 +31,8 @@ public class TestTermStates extends LuceneTestCase {
     RandomIndexWriter w = new RandomIndexWriter(random(), dir);
     w.addDocument(new Document());
     IndexReader r = w.getReader();
-    TermStates states =
-        TermStates.build(r.getContext(), new Term("foo", "bar"), random().nextBoolean());
+    IndexSearcher s = new IndexSearcher(r);
+    TermStates states = TermStates.build(s, new Term("foo", "bar"), random().nextBoolean());
     assertEquals("TermStates\n  state=null\n", states.toString());
     IOUtils.close(r, w, dir);
   }

@@ -365,7 +365,7 @@ public class TestMinShouldMatch2 extends LuceneTestCase {
       if (ord >= 0) {
         boolean success = ords.add(ord);
         assert success; // no dups
-        TermStates ts = TermStates.build(reader.getContext(), term, true);
+        TermStates ts = TermStates.build(searcher, term, true);
         SimScorer w =
             weight.similarity.scorer(
                 1f,

@@ -55,11 +55,12 @@ public class TestTermQuery extends LuceneTestCase {
     final CompositeReaderContext context;
     try (MultiReader multiReader = new MultiReader()) {
       context = multiReader.getContext();
     }
+    IndexSearcher searcher = new IndexSearcher(context);
     QueryUtils.checkEqual(
         new TermQuery(new Term("foo", "bar")),
         new TermQuery(
-            new Term("foo", "bar"), TermStates.build(context, new Term("foo", "bar"), true)));
+            new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true)));
   }
 
   public void testCreateWeightDoesNotSeekIfScoresAreNotNeeded() throws IOException {

@@ -100,8 +101,7 @@ public class TestTermQuery extends LuceneTestCase {
     assertEquals(1, totalHits);
     TermQuery queryWithContext =
         new TermQuery(
-            new Term("foo", "bar"),
-            TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
+            new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
     totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager());
     assertEquals(1, totalHits);
 
@@ -160,10 +160,10 @@ public class TestTermQuery extends LuceneTestCase {
     w.addDocument(new Document());
 
     DirectoryReader reader = w.getReader();
+    IndexSearcher searcher = new IndexSearcher(reader);
     TermQuery queryWithContext =
         new TermQuery(
-            new Term("foo", "bar"),
-            TermStates.build(reader.getContext(), new Term("foo", "bar"), true));
+            new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true));
     assertNotNull(queryWithContext.getTermStates());
     IOUtils.close(reader, w, dir);
   }

@@ -82,7 +82,7 @@ public class SpanTermQuery extends SpanQuery {
     final TermStates context;
     final IndexReaderContext topContext = searcher.getTopReaderContext();
     if (termStates == null || termStates.wasBuiltFor(topContext) == false) {
-      context = TermStates.build(topContext, term, scoreMode.needsScores());
+      context = TermStates.build(searcher, term, scoreMode.needsScores());
     } else {
       context = termStates;
     }

@@ -330,7 +330,7 @@ public final class CombinedFieldQuery extends Query implements Accountable {
     termStates = new TermStates[fieldTerms.length];
     for (int i = 0; i < termStates.length; i++) {
       FieldAndWeight field = fieldAndWeights.get(fieldTerms[i].field());
-      TermStates ts = TermStates.build(searcher.getTopReaderContext(), fieldTerms[i], true);
+      TermStates ts = TermStates.build(searcher, fieldTerms[i], true);
       termStates[i] = ts;
       if (ts.docFreq() > 0) {
         TermStatistics termStats =

@@ -375,7 +375,7 @@ public class PhraseWildcardQuery extends Query {
       TermData termData = termsData.getOrCreateTermData(singleTerm.termPosition);
       Term term = singleTerm.term;
       termData.terms.add(term);
-      TermStates termStates = TermStates.build(searcher.getIndexReader().getContext(), term, true);
+      TermStates termStates = TermStates.build(searcher, term, true);
 
       // Collect TermState per segment.
       int numMatches = 0;

@@ -23,7 +23,6 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import org.apache.lucene.index.IndexReaderContext;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.PostingsEnum;
 import org.apache.lucene.index.ReaderUtil;

@@ -209,14 +208,13 @@ public class TermAutomatonQuery extends Query implements Accountable {
   @Override
   public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost)
       throws IOException {
-    IndexReaderContext context = searcher.getTopReaderContext();
     Map<Integer, TermStates> termStates = new HashMap<>();
 
     for (Map.Entry<BytesRef, Integer> ent : termToID.entrySet()) {
       if (ent.getKey() != null) {
         termStates.put(
             ent.getValue(),
-            TermStates.build(context, new Term(field, ent.getKey()), scoreMode.needsScores()));
+            TermStates.build(searcher, new Term(field, ent.getKey()), scoreMode.needsScores()));
       }
     }
 
@@ -207,7 +207,7 @@ public abstract class ShardSearchingTestBase extends LuceneTestCase {
       }
       try {
         for (Term term : terms) {
-          final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true);
+          final TermStates ts = TermStates.build(s, term, true);
           if (ts.docFreq() > 0) {
             stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq()));
           }