diff --git a/core/src/main/java/org/elasticsearch/search/profile/query/ProfileWeight.java b/core/src/main/java/org/elasticsearch/search/profile/query/ProfileWeight.java index 4361267bfe6..7cb50b29219 100644 --- a/core/src/main/java/org/elasticsearch/search/profile/query/ProfileWeight.java +++ b/core/src/main/java/org/elasticsearch/search/profile/query/ProfileWeight.java @@ -25,6 +25,7 @@ import org.apache.lucene.search.BulkScorer; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.Query; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.elasticsearch.search.profile.Timer; @@ -49,19 +50,50 @@ public final class ProfileWeight extends Weight { @Override public Scorer scorer(LeafReaderContext context) throws IOException { + ScorerSupplier supplier = scorerSupplier(context); + if (supplier == null) { + return null; + } + return supplier.get(false); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { Timer timer = profile.getTimer(QueryTimingType.BUILD_SCORER); timer.start(); - final Scorer subQueryScorer; + final ScorerSupplier subQueryScorerSupplier; try { - subQueryScorer = subQueryWeight.scorer(context); + subQueryScorerSupplier = subQueryWeight.scorerSupplier(context); } finally { timer.stop(); } - if (subQueryScorer == null) { + if (subQueryScorerSupplier == null) { return null; } - return new ProfileScorer(this, subQueryScorer, profile); + final ProfileWeight weight = this; + return new ScorerSupplier() { + + @Override + public Scorer get(boolean randomAccess) throws IOException { + timer.start(); + try { + return new ProfileScorer(weight, subQueryScorerSupplier.get(randomAccess), profile); + } finally { + timer.stop(); + } + } + + @Override + public long cost() { + timer.start(); + try { + return subQueryScorerSupplier.cost(); + } finally { + timer.stop(); + } + } + }; } @Override diff --git a/core/src/test/java/org/elasticsearch/search/profile/query/QueryProfilerTests.java b/core/src/test/java/org/elasticsearch/search/profile/query/QueryProfilerTests.java index 76de9d56e32..43c6018d8f8 100644 --- a/core/src/test/java/org/elasticsearch/search/profile/query/QueryProfilerTests.java +++ b/core/src/test/java/org/elasticsearch/search/profile/query/QueryProfilerTests.java @@ -22,16 +22,24 @@ package org.elasticsearch.search.profile.query; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field.Store; import org.apache.lucene.document.StringField; +import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.RandomIndexWriter; import org.apache.lucene.index.Term; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.LeafCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.RandomApproximationQuery; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Sort; import org.apache.lucene.search.TermQuery; import org.apache.lucene.search.TotalHitCountCollector; +import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.TestUtil; @@ -45,6 +53,7 @@ import org.junit.BeforeClass; import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.Set; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.greaterThan; @@ -191,4 +200,76 @@ public class QueryProfilerTests extends ESTestCase { leafCollector.collect(0); assertThat(profileCollector.getTime(), greaterThan(time)); } + + private static class DummyQuery extends Query { + + @Override + public String toString(String field) { + return getClass().getSimpleName(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public Weight createWeight(IndexSearcher searcher, boolean needsScores, float boost) throws IOException { + return new Weight(this) { + @Override + public void extractTerms(Set terms) { + throw new UnsupportedOperationException(); + } + + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException { + final Weight weight = this; + return new ScorerSupplier() { + + @Override + public Scorer get(boolean randomAccess) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return 42; + } + }; + } + }; + } + } + + public void testScorerSupplier() throws IOException { + Directory dir = newDirectory(); + IndexWriter w = new IndexWriter(dir, newIndexWriterConfig()); + w.addDocument(new Document()); + DirectoryReader reader = DirectoryReader.open(w); + w.close(); + IndexSearcher s = newSearcher(reader); + s.setQueryCache(null); + Weight weight = s.createNormalizedWeight(new DummyQuery(), randomBoolean()); + // exception when getting the scorer + expectThrows(UnsupportedOperationException.class, () -> weight.scorer(s.getIndexReader().leaves().get(0))); + // no exception, means scorerSupplier is delegated + weight.scorerSupplier(s.getIndexReader().leaves().get(0)); + reader.close(); + dir.close(); + } }