From fb1f4dd412db2c415e945def532459dc899df199 Mon Sep 17 00:00:00 2001 From: Shubham Chaudhary <36742242+shubhamvishu@users.noreply.github.com> Date: Thu, 21 Sep 2023 20:47:36 +0530 Subject: [PATCH] Make TermStates#build concurrent (#12183) --- lucene/CHANGES.txt | 2 +- .../apache/lucene/document/FeatureField.java | 11 ++- .../org/apache/lucene/index/TermStates.java | 87 ++++++++++++++++--- .../lucene/search/BlendedTermQuery.java | 2 +- .../lucene/search/MultiPhraseQuery.java | 4 +- .../org/apache/lucene/search/PhraseQuery.java | 4 +- .../apache/lucene/search/SynonymQuery.java | 2 +- .../org/apache/lucene/search/TermQuery.java | 2 +- .../lucene/document/TestFeatureField.java | 6 +- .../apache/lucene/index/TestTermStates.java | 5 +- .../lucene/search/TestMinShouldMatch2.java | 2 +- .../apache/lucene/search/TestTermQuery.java | 16 ++-- .../lucene/queries/spans/SpanTermQuery.java | 2 +- .../sandbox/search/CombinedFieldQuery.java | 2 +- .../sandbox/search/PhraseWildcardQuery.java | 2 +- .../sandbox/search/TermAutomatonQuery.java | 4 +- .../tests/search/ShardSearchingTestBase.java | 2 +- 17 files changed, 109 insertions(+), 46 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 61f44825060..fc1ef72507e 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -148,7 +148,7 @@ Improvements Optimizations --------------------- -(No changes) +* GITHUB#12183: Make TermStates#build concurrent. (Shubham Chaudhary) Changes in runtime behavior --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java index 85c3bad3519..53edaad0eac 100644 --- a/lucene/core/src/java/org/apache/lucene/document/FeatureField.java +++ b/lucene/core/src/java/org/apache/lucene/document/FeatureField.java @@ -23,7 +23,6 @@ import org.apache.lucene.analysis.TokenStream; import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; import org.apache.lucene.analysis.tokenattributes.TermFrequencyAttribute; import org.apache.lucene.index.IndexOptions; -import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.Term; import org.apache.lucene.index.TermStates; import org.apache.lucene.search.BooleanQuery; @@ -345,7 +344,7 @@ public final class FeatureField extends Field { if (pivot != null) { return super.rewrite(indexSearcher); } - float newPivot = computePivotFeatureValue(indexSearcher.getIndexReader(), field, feature); + float newPivot = computePivotFeatureValue(indexSearcher, field, feature); return new SaturationFunction(field, feature, newPivot); } @@ -618,14 +617,14 @@ public final class FeatureField extends Field { * store the exponent in the higher bits, it means that the result will be an approximation of the * geometric mean of all feature values. * - * @param reader the {@link IndexReader} to search against + * @param searcher the {@link IndexSearcher} to perform the search * @param featureField the field that stores features * @param featureName the name of the feature */ - static float computePivotFeatureValue(IndexReader reader, String featureField, String featureName) - throws IOException { + static float computePivotFeatureValue( + IndexSearcher searcher, String featureField, String featureName) throws IOException { Term term = new Term(featureField, featureName); - TermStates states = TermStates.build(reader.getContext(), term, true); + TermStates states = TermStates.build(searcher, term, true); if (states.docFreq() == 0) { // avoid division by 0 // The return value doesn't matter much here, the term doesn't exist, diff --git a/lucene/core/src/java/org/apache/lucene/index/TermStates.java b/lucene/core/src/java/org/apache/lucene/index/TermStates.java index bf4b97e86e5..a8472bac0d0 100644 --- a/lucene/core/src/java/org/apache/lucene/index/TermStates.java +++ b/lucene/core/src/java/org/apache/lucene/index/TermStates.java @@ -18,6 +18,9 @@ package org.apache.lucene.index; import java.io.IOException; import java.util.Arrays; +import java.util.List; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.TaskExecutor; /** * Maintains a {@link IndexReader} {@link TermState} view over {@link IndexReader} instances @@ -86,19 +89,47 @@ public final class TermStates { * @param needsStats if {@code true} then all leaf contexts will be visited up-front to collect * term statistics. Otherwise, the {@link TermState} objects will be built only when requested */ - public static TermStates build(IndexReaderContext context, Term term, boolean needsStats) + public static TermStates build(IndexSearcher indexSearcher, Term term, boolean needsStats) throws IOException { - assert context != null && context.isTopLevel; + IndexReaderContext context = indexSearcher.getTopReaderContext(); + assert context != null; final TermStates perReaderTermState = new TermStates(needsStats ? null : term, context); if (needsStats) { - for (final LeafReaderContext ctx : context.leaves()) { - // if (DEBUG) System.out.println(" r=" + leaves[i].reader); - TermsEnum termsEnum = loadTermsEnum(ctx, term); - if (termsEnum != null) { - final TermState termState = termsEnum.termState(); - // if (DEBUG) System.out.println(" found"); - perReaderTermState.register( - termState, ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq()); + TaskExecutor taskExecutor = indexSearcher.getTaskExecutor(); + if (taskExecutor != null) { + // build the term states concurrently + List> tasks = + context.leaves().stream() + .map( + ctx -> + taskExecutor.createTask( + () -> { + TermsEnum termsEnum = loadTermsEnum(ctx, term); + if (termsEnum != null) { + return new TermStateInfo( + termsEnum.termState(), + ctx.ord, + termsEnum.docFreq(), + termsEnum.totalTermFreq()); + } + return null; + })) + .toList(); + List resultInfos = taskExecutor.invokeAll(tasks); + for (TermStateInfo info : resultInfos) { + if (info != null) { + perReaderTermState.register( + info.getState(), info.getOrdinal(), info.getDocFreq(), info.getTotalTermFreq()); + } + } + } else { + // build the term states sequentially + for (final LeafReaderContext ctx : context.leaves()) { + TermsEnum termsEnum = loadTermsEnum(ctx, term); + if (termsEnum != null) { + perReaderTermState.register( + termsEnum.termState(), ctx.ord, termsEnum.docFreq(), termsEnum.totalTermFreq()); + } } } } @@ -211,4 +242,40 @@ public final class TermStates { return sb.toString(); } + + /** Wrapper over TermState, ordinal value, term doc frequency and total term frequency */ + private static final class TermStateInfo { + private final TermState state; + private final int ordinal; + private final int docFreq; + private final long totalTermFreq; + + /** Initialize TermStateInfo */ + public TermStateInfo(TermState state, int ordinal, int docFreq, long totalTermFreq) { + this.state = state; + this.ordinal = ordinal; + this.docFreq = docFreq; + this.totalTermFreq = totalTermFreq; + } + + /** Get term state */ + public TermState getState() { + return state; + } + + /** Get ordinal value */ + public int getOrdinal() { + return ordinal; + } + + /** Get term doc frequency */ + public int getDocFreq() { + return docFreq; + } + + /** Get total term frequency */ + public long getTotalTermFreq() { + return totalTermFreq; + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java index 2c7e41ac971..05d86819486 100644 --- a/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/BlendedTermQuery.java @@ -272,7 +272,7 @@ public final class BlendedTermQuery extends Query { for (int i = 0; i < contexts.length; ++i) { if (contexts[i] == null || contexts[i].wasBuiltFor(indexSearcher.getTopReaderContext()) == false) { - contexts[i] = TermStates.build(indexSearcher.getTopReaderContext(), terms[i], true); + contexts[i] = TermStates.build(indexSearcher, terms[i], true); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java index 27819235f64..23cd178bfed 100644 --- a/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/MultiPhraseQuery.java @@ -24,7 +24,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; -import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; @@ -219,7 +218,6 @@ public class MultiPhraseQuery extends Query { @Override protected Similarity.SimScorer getStats(IndexSearcher searcher) throws IOException { - final IndexReaderContext context = searcher.getTopReaderContext(); // compute idf ArrayList allTermStats = new ArrayList<>(); @@ -227,7 +225,7 @@ public class MultiPhraseQuery extends Query { for (Term term : terms) { TermStates ts = termStates.get(term); if (ts == null) { - ts = TermStates.build(context, term, scoreMode.needsScores()); + ts = TermStates.build(searcher, term, scoreMode.needsScores()); termStates.put(term, ts); } if (scoreMode.needsScores() && ts.docFreq() > 0) { diff --git a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java index 64386165136..853e8f9b0c7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/PhraseQuery.java @@ -24,7 +24,6 @@ import java.util.Objects; import org.apache.lucene.codecs.lucene90.Lucene90PostingsFormat; import org.apache.lucene.codecs.lucene90.Lucene90PostingsReader; import org.apache.lucene.index.ImpactsEnum; -import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; @@ -451,13 +450,12 @@ public class PhraseQuery extends Query { throw new IllegalStateException( "PhraseWeight requires that the first position is 0, call rewrite first"); } - final IndexReaderContext context = searcher.getTopReaderContext(); states = new TermStates[terms.length]; TermStatistics[] termStats = new TermStatistics[terms.length]; int termUpTo = 0; for (int i = 0; i < terms.length; i++) { final Term term = terms[i]; - states[i] = TermStates.build(context, term, scoreMode.needsScores()); + states[i] = TermStates.build(searcher, term, scoreMode.needsScores()); if (scoreMode.needsScores()) { TermStates ts = states[i]; if (ts.docFreq() > 0) { diff --git a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java index 2ab11ed495a..f44ceaf3ac8 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java @@ -207,7 +207,7 @@ public final class SynonymQuery extends Query { termStates = new TermStates[terms.length]; for (int i = 0; i < termStates.length; i++) { Term term = new Term(field, terms[i].term); - TermStates ts = TermStates.build(searcher.getTopReaderContext(), term, true); + TermStates ts = TermStates.build(searcher, term, true); termStates[i] = ts; if (ts.docFreq() > 0) { TermStatistics termStats = diff --git a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java index 58251a3f6d7..dcf76c74ac7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TermQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/TermQuery.java @@ -272,7 +272,7 @@ public class TermQuery extends Query { final IndexReaderContext context = searcher.getTopReaderContext(); final TermStates termState; if (perReaderTermState == null || perReaderTermState.wasBuiltFor(context) == false) { - termState = TermStates.build(context, term, scoreMode.needsScores()); + termState = TermStates.build(searcher, term, scoreMode.needsScores()); } else { // PRTS was pre-build for this IS termState = this.perReaderTermState; diff --git a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java index be964543d42..3bccc17c0a6 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestFeatureField.java @@ -272,7 +272,8 @@ public class TestFeatureField extends LuceneTestCase { // Make sure that we create a legal pivot on missing features DirectoryReader reader = writer.getReader(); - float pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); + IndexSearcher searcher = LuceneTestCase.newSearcher(reader); + float pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank"); assertTrue(Float.isFinite(pivot)); assertTrue(pivot > 0); reader.close(); @@ -298,7 +299,8 @@ public class TestFeatureField extends LuceneTestCase { reader = writer.getReader(); writer.close(); - pivot = FeatureField.computePivotFeatureValue(reader, "features", "pagerank"); + searcher = LuceneTestCase.newSearcher(reader); + pivot = FeatureField.computePivotFeatureValue(searcher, "features", "pagerank"); double expected = Math.pow(10 * 100 * 1 * 42, 1 / 4.); // geometric mean assertEquals(expected, pivot, 0.1); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestTermStates.java b/lucene/core/src/test/org/apache/lucene/index/TestTermStates.java index 85142c2caa3..c5e69ce40af 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestTermStates.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestTermStates.java @@ -18,6 +18,7 @@ package org.apache.lucene.index; import org.apache.lucene.document.Document; +import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.store.Directory; import org.apache.lucene.tests.index.RandomIndexWriter; import org.apache.lucene.tests.util.LuceneTestCase; @@ -30,8 +31,8 @@ public class TestTermStates extends LuceneTestCase { RandomIndexWriter w = new RandomIndexWriter(random(), dir); w.addDocument(new Document()); IndexReader r = w.getReader(); - TermStates states = - TermStates.build(r.getContext(), new Term("foo", "bar"), random().nextBoolean()); + IndexSearcher s = new IndexSearcher(r); + TermStates states = TermStates.build(s, new Term("foo", "bar"), random().nextBoolean()); assertEquals("TermStates\n state=null\n", states.toString()); IOUtils.close(r, w, dir); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestMinShouldMatch2.java b/lucene/core/src/test/org/apache/lucene/search/TestMinShouldMatch2.java index f3c8893f418..1b6d0778a2e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestMinShouldMatch2.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestMinShouldMatch2.java @@ -365,7 +365,7 @@ public class TestMinShouldMatch2 extends LuceneTestCase { if (ord >= 0) { boolean success = ords.add(ord); assert success; // no dups - TermStates ts = TermStates.build(reader.getContext(), term, true); + TermStates ts = TermStates.build(searcher, term, true); SimScorer w = weight.similarity.scorer( 1f, diff --git a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java index 473f6c9a8f0..72e8792b1be 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTermQuery.java @@ -55,11 +55,12 @@ public class TestTermQuery extends LuceneTestCase { final CompositeReaderContext context; try (MultiReader multiReader = new MultiReader()) { context = multiReader.getContext(); + IndexSearcher searcher = new IndexSearcher(context); + QueryUtils.checkEqual( + new TermQuery(new Term("foo", "bar")), + new TermQuery( + new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true))); } - QueryUtils.checkEqual( - new TermQuery(new Term("foo", "bar")), - new TermQuery( - new Term("foo", "bar"), TermStates.build(context, new Term("foo", "bar"), true))); } public void testCreateWeightDoesNotSeekIfScoresAreNotNeeded() throws IOException { @@ -100,8 +101,7 @@ public class TestTermQuery extends LuceneTestCase { assertEquals(1, totalHits); TermQuery queryWithContext = new TermQuery( - new Term("foo", "bar"), - TermStates.build(reader.getContext(), new Term("foo", "bar"), true)); + new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true)); totalHits = searcher.search(queryWithContext, DummyTotalHitCountCollector.createManager()); assertEquals(1, totalHits); @@ -160,10 +160,10 @@ public class TestTermQuery extends LuceneTestCase { w.addDocument(new Document()); DirectoryReader reader = w.getReader(); + IndexSearcher searcher = new IndexSearcher(reader); TermQuery queryWithContext = new TermQuery( - new Term("foo", "bar"), - TermStates.build(reader.getContext(), new Term("foo", "bar"), true)); + new Term("foo", "bar"), TermStates.build(searcher, new Term("foo", "bar"), true)); assertNotNull(queryWithContext.getTermStates()); IOUtils.close(reader, w, dir); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanTermQuery.java b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanTermQuery.java index a38019a39f2..39081eb48c1 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanTermQuery.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/spans/SpanTermQuery.java @@ -82,7 +82,7 @@ public class SpanTermQuery extends SpanQuery { final TermStates context; final IndexReaderContext topContext = searcher.getTopReaderContext(); if (termStates == null || termStates.wasBuiltFor(topContext) == false) { - context = TermStates.build(topContext, term, scoreMode.needsScores()); + context = TermStates.build(searcher, term, scoreMode.needsScores()); } else { context = termStates; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java index 08bb24a846c..f7afa5b4e56 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/CombinedFieldQuery.java @@ -330,7 +330,7 @@ public final class CombinedFieldQuery extends Query implements Accountable { termStates = new TermStates[fieldTerms.length]; for (int i = 0; i < termStates.length; i++) { FieldAndWeight field = fieldAndWeights.get(fieldTerms[i].field()); - TermStates ts = TermStates.build(searcher.getTopReaderContext(), fieldTerms[i], true); + TermStates ts = TermStates.build(searcher, fieldTerms[i], true); termStates[i] = ts; if (ts.docFreq() > 0) { TermStatistics termStats = diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java index bbe8a4970eb..cfda65505c1 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/PhraseWildcardQuery.java @@ -375,7 +375,7 @@ public class PhraseWildcardQuery extends Query { TermData termData = termsData.getOrCreateTermData(singleTerm.termPosition); Term term = singleTerm.term; termData.terms.add(term); - TermStates termStates = TermStates.build(searcher.getIndexReader().getContext(), term, true); + TermStates termStates = TermStates.build(searcher, term, true); // Collect TermState per segment. int numMatches = 0; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java index 7fae86711d8..f61f9d8fc00 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/search/TermAutomatonQuery.java @@ -23,7 +23,6 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import org.apache.lucene.index.IndexReaderContext; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.PostingsEnum; import org.apache.lucene.index.ReaderUtil; @@ -209,14 +208,13 @@ public class TermAutomatonQuery extends Query implements Accountable { @Override public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { - IndexReaderContext context = searcher.getTopReaderContext(); Map termStates = new HashMap<>(); for (Map.Entry ent : termToID.entrySet()) { if (ent.getKey() != null) { termStates.put( ent.getValue(), - TermStates.build(context, new Term(field, ent.getKey()), scoreMode.needsScores())); + TermStates.build(searcher, new Term(field, ent.getKey()), scoreMode.needsScores())); } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/ShardSearchingTestBase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/ShardSearchingTestBase.java index 0084e16631a..b72c7f3f8f4 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/ShardSearchingTestBase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/ShardSearchingTestBase.java @@ -207,7 +207,7 @@ public abstract class ShardSearchingTestBase extends LuceneTestCase { } try { for (Term term : terms) { - final TermStates ts = TermStates.build(s.getIndexReader().getContext(), term, true); + final TermStates ts = TermStates.build(s, term, true); if (ts.docFreq() > 0) { stats.put(term, s.termStatistics(term, ts.docFreq(), ts.totalTermFreq())); }