Concurrent Searching (Experimental) (#1500)

* Concurrent Searching (Experimental)

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

* Addressing code review comments

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
Andriy Redko 2022-03-24 14:20:31 -04:00 committed by GitHub
parent 2e9f89a89e
commit b6ca0d1f78
38 changed files with 5563 additions and 171 deletions

View File

@ -0,0 +1,42 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
apply plugin: 'opensearch.opensearchplugin'
apply plugin: 'opensearch.yaml-rest-test'
opensearchplugin {
name 'concurrent-search'
description 'The experimental plugin which implements concurrent search over Apache Lucene segments'
classname 'org.opensearch.search.ConcurrentSegmentSearchPlugin'
licenseFile rootProject.file('licenses/APACHE-LICENSE-2.0.txt')
noticeFile rootProject.file('NOTICE.txt')
}
yamlRestTest.enabled = false;
testingConventions.enabled = false;

View File

@ -0,0 +1,53 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.plugins.Plugin;
import org.opensearch.plugins.SearchPlugin;
import org.opensearch.search.query.ConcurrentQueryPhaseSearcher;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
/**
 * The experimental plugin which implements concurrent search over Apache Lucene segments.
*/
public class ConcurrentSegmentSearchPlugin extends Plugin implements SearchPlugin {
private static final String INDEX_SEARCHER = "index_searcher";
/**
* Default constructor
*/
public ConcurrentSegmentSearchPlugin() {}
@Override
public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() {
return Optional.of(new ConcurrentQueryPhaseSearcher());
}
@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
final int allocatedProcessors = OpenSearchExecutors.allocatedProcessors(settings);
return Collections.singletonList(
new FixedExecutorBuilder(settings, INDEX_SEARCHER, allocatedProcessors, 1000, "thread_pool." + INDEX_SEARCHER)
);
}
@Override
public Optional<ExecutorServiceProvider> getIndexSearcherExecutorProvider() {
return Optional.of((ThreadPool threadPool) -> threadPool.executor(INDEX_SEARCHER));
}
}
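
For context on what the executor wiring above buys: below is a minimal, self-contained sketch (not part of this change, and relying only on stock Lucene APIs such as the IndexSearcher(IndexReader, Executor) constructor and TotalHitCountCollectorManager) of the mechanism the plugin leans on. When an IndexSearcher is constructed with an executor and searched through a CollectorManager, Lucene collects segment slices concurrently on that executor:

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.TotalHitCountCollectorManager;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class ConcurrentSearchSketch {
    public static void main(String[] args) throws Exception {
        Directory dir = new ByteBuffersDirectory();
        try (IndexWriter w = new IndexWriter(dir, new IndexWriterConfig())) {
            Document doc = new Document();
            doc.add(new StringField("foo", "bar", Store.NO));
            w.addDocument(doc);
        }
        ExecutorService executor = Executors.newFixedThreadPool(4);
        try (DirectoryReader reader = DirectoryReader.open(dir)) {
            // The executor is what enables per-slice concurrency; without it the
            // same search(query, manager) call runs sequentially on one thread.
            IndexSearcher searcher = new IndexSearcher(reader, executor);
            int hits = searcher.search(new MatchAllDocsQuery(), new TotalHitCountCollectorManager());
            System.out.println("hits=" + hits);
        } finally {
            executor.shutdown();
            dir.close();
        }
    }
}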

View File

@ -0,0 +1,12 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/**
 * The implementation of the experimental plugin for concurrent search over Apache Lucene segments.
*/
package org.opensearch.search;

View File

@ -0,0 +1,119 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;
import java.io.IOException;
import java.util.LinkedList;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.query.ProfileCollectorManager;
import org.opensearch.search.query.QueryPhase.DefaultQueryPhaseSearcher;
import org.opensearch.search.query.QueryPhase.TimeExceededException;
/**
* The implementation of the {@link QueryPhaseSearcher} which attempts to use concurrent
* search of Apache Lucene segments if it has been enabled.
*/
public class ConcurrentQueryPhaseSearcher extends DefaultQueryPhaseSearcher {
private static final Logger LOGGER = LogManager.getLogger(ConcurrentQueryPhaseSearcher.class);
/**
* Default constructor
*/
public ConcurrentQueryPhaseSearcher() {}
@Override
protected boolean searchWithCollector(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectors,
boolean hasFilterCollector,
boolean hasTimeout
) throws IOException {
boolean couldUseConcurrentSegmentSearch = allowConcurrentSegmentSearch(searcher);
// TODO: support aggregations
if (searchContext.aggregations() != null) {
couldUseConcurrentSegmentSearch = false;
LOGGER.debug("Unable to use concurrent search over index segments (experimental): aggregations are present");
}
if (couldUseConcurrentSegmentSearch) {
LOGGER.debug("Using concurrent search over index segments (experimental)");
return searchWithCollectorManager(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
} else {
return super.searchWithCollector(searchContext, searcher, query, collectors, hasFilterCollector, hasTimeout);
}
}
private static boolean searchWithCollectorManager(
SearchContext searchContext,
ContextIndexSearcher searcher,
Query query,
LinkedList<QueryCollectorContext> collectorContexts,
boolean hasFilterCollector,
boolean timeoutSet
) throws IOException {
// create the top docs collector last when the other collectors are known
final TopDocsCollectorContext topDocsFactory = createTopDocsCollectorContext(searchContext, hasFilterCollector);
// add the top docs collector, the first collector context in the chain
collectorContexts.addFirst(topDocsFactory);
final QuerySearchResult queryResult = searchContext.queryResult();
final CollectorManager<?, ReduceableSearchResult> collectorManager;
// TODO: support aggregations in concurrent segment search flow
if (searchContext.aggregations() != null) {
throw new UnsupportedOperationException("The concurrent segment search does not support aggregations yet");
}
if (searchContext.getProfilers() != null) {
final ProfileCollectorManager<? extends Collector, ReduceableSearchResult> profileCollectorManager =
QueryCollectorManagerContext.createQueryCollectorManagerWithProfiler(collectorContexts);
searchContext.getProfilers().getCurrentQueryProfiler().setCollector(profileCollectorManager);
collectorManager = profileCollectorManager;
} else {
// Create multi collector manager instance
collectorManager = QueryCollectorManagerContext.createMultiCollectorManager(collectorContexts);
}
try {
final ReduceableSearchResult result = searcher.search(query, collectorManager);
result.reduce(queryResult);
} catch (EarlyTerminatingCollector.EarlyTerminationException e) {
queryResult.terminatedEarly(true);
} catch (TimeExceededException e) {
assert timeoutSet : "TimeExceededException thrown even though timeout wasn't set";
if (searchContext.request().allowPartialSearchResults() == false) {
// Can't rethrow TimeExceededException because not serializable
throw new QueryPhaseExecutionException(searchContext.shardTarget(), "Time exceeded");
}
queryResult.searchTimedOut(true);
}
if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
queryResult.terminatedEarly(false);
}
return topDocsFactory.shouldRescore();
}
private static boolean allowConcurrentSegmentSearch(final ContextIndexSearcher searcher) {
return (searcher.getExecutor() != null);
}
}
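
The CollectorManager contract that searchWithCollectorManager relies on is: one collector per slice via newCollector(), concurrent collection, then a single reduce() over all per-slice collectors. A minimal illustrative manager (plain Lucene, not from this change; CountingManager is a hypothetical name) showing that lifecycle:

import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.TotalHitCountCollector;
import java.io.IOException;
import java.util.Collection;

// Each searcher thread obtains its own collector from newCollector(), so no
// synchronization is needed during collection; reduce() merges the per-slice
// results once every slice has finished.
class CountingManager implements CollectorManager<TotalHitCountCollector, Integer> {
    @Override
    public TotalHitCountCollector newCollector() {
        return new TotalHitCountCollector();
    }

    @Override
    public Integer reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
        int total = 0;
        for (TotalHitCountCollector collector : collectors) {
            total += collector.getTotalHits();
        }
        return total;
    }
}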

View File

@ -0,0 +1,12 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
/**
* {@link org.opensearch.search.query.QueryPhaseSearcher} implementation for concurrent search
*/
package org.opensearch.search.query;

View File

@ -0,0 +1,316 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.profile.query;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LRUQueryCache;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryCachingPolicy;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.tests.search.RandomApproximationQuery;
import org.apache.lucene.tests.util.TestUtil;
import org.opensearch.core.internal.io.IOUtils;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
public class QueryProfilerTests extends OpenSearchTestCase {
private Directory dir;
private IndexReader reader;
private ContextIndexSearcher searcher;
private ExecutorService executor;
@ParametersFactory
public static Collection<Object[]> concurrency() {
return Arrays.asList(new Integer[] { 0 }, new Integer[] { 5 });
}
public QueryProfilerTests(int concurrency) {
this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null;
}
@Before
public void setUp() throws Exception {
super.setUp();
dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir);
final int numDocs = TestUtil.nextInt(random(), 1, 20);
for (int i = 0; i < numDocs; ++i) {
final int numHoles = random().nextInt(5);
for (int j = 0; j < numHoles; ++j) {
w.addDocument(new Document());
}
Document doc = new Document();
doc.add(new StringField("foo", "bar", Store.NO));
w.addDocument(doc);
}
reader = w.getReader();
w.close();
searcher = new ContextIndexSearcher(
reader,
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
ALWAYS_CACHE_POLICY,
true,
executor
);
}
@After
public void tearDown() throws Exception {
super.tearDown();
LRUQueryCache cache = (LRUQueryCache) searcher.getQueryCache();
assertThat(cache.getHitCount(), equalTo(0L));
assertThat(cache.getCacheCount(), equalTo(0L));
assertThat(cache.getTotalCount(), equalTo(cache.getMissCount()));
assertThat(cache.getCacheSize(), equalTo(0L));
if (executor != null) {
ThreadPool.terminate(executor, 10, TimeUnit.SECONDS);
}
IOUtils.close(reader, dir);
dir = null;
reader = null;
searcher = null;
}
public void testBasic() throws IOException {
QueryProfiler profiler = new QueryProfiler(executor != null);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.search(query, 1);
List<ProfileResult> results = profiler.getTree();
assertEquals(1, results.size());
Map<String, Long> breakdown = results.get(0).getTimeBreakdown();
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), equalTo(0L));
long rewriteTime = profiler.getRewriteTime();
assertThat(rewriteTime, greaterThan(0L));
}
public void testNoScoring() throws IOException {
QueryProfiler profiler = new QueryProfiler(executor != null);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.search(query, 1, Sort.INDEXORDER); // scores are not needed
List<ProfileResult> results = profiler.getTree();
assertEquals(1, results.size());
Map<String, Long> breakdown = results.get(0).getTimeBreakdown();
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), equalTo(0L));
long rewriteTime = profiler.getRewriteTime();
assertThat(rewriteTime, greaterThan(0L));
}
public void testUseIndexStats() throws IOException {
QueryProfiler profiler = new QueryProfiler(executor != null);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.count(query); // will use index stats
List<ProfileResult> results = profiler.getTree();
assertEquals(1, results.size());
ProfileResult result = results.get(0);
assertEquals(0, (long) result.getTimeBreakdown().get("build_scorer_count"));
long rewriteTime = profiler.getRewriteTime();
assertThat(rewriteTime, greaterThan(0L));
}
public void testApproximations() throws IOException {
QueryProfiler profiler = new QueryProfiler(executor != null);
searcher.setProfiler(profiler);
Query query = new RandomApproximationQuery(new TermQuery(new Term("foo", "bar")), random());
searcher.count(query);
List<ProfileResult> results = profiler.getTree();
assertEquals(1, results.size());
Map<String, Long> breakdown = results.get(0).getTimeBreakdown();
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString()), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString()), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.CREATE_WEIGHT.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.BUILD_SCORER.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.NEXT_DOC.toString() + "_count"), greaterThan(0L));
assertThat(breakdown.get(QueryTimingType.ADVANCE.toString() + "_count"), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.SCORE.toString() + "_count"), equalTo(0L));
assertThat(breakdown.get(QueryTimingType.MATCH.toString() + "_count"), greaterThan(0L));
long rewriteTime = profiler.getRewriteTime();
assertThat(rewriteTime, greaterThan(0L));
}
public void testCollector() throws IOException {
TotalHitCountCollector collector = new TotalHitCountCollector();
ProfileCollector profileCollector = new ProfileCollector(collector);
assertEquals(0, profileCollector.getTime());
final LeafCollector leafCollector = profileCollector.getLeafCollector(reader.leaves().get(0));
assertThat(profileCollector.getTime(), greaterThan(0L));
long time = profileCollector.getTime();
leafCollector.setScorer(null);
assertThat(profileCollector.getTime(), greaterThan(time));
time = profileCollector.getTime();
leafCollector.collect(0);
assertThat(profileCollector.getTime(), greaterThan(time));
}
private static class DummyQuery extends Query {
@Override
public String toString(String field) {
return getClass().getSimpleName();
}
@Override
public boolean equals(Object obj) {
return this == obj;
}
@Override
public int hashCode() {
return 0;
}
@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new Weight(this) {
@Override
public Explanation explain(LeafReaderContext context, int doc) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
return new ScorerSupplier() {
@Override
public Scorer get(long loadCost) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public long cost() {
return 42;
}
};
}
@Override
public boolean isCacheable(LeafReaderContext ctx) {
return true;
}
};
}
@Override
public void visit(QueryVisitor visitor) {
visitor.visitLeaf(this);
}
}
public void testScorerSupplier() throws IOException {
Directory dir = newDirectory();
IndexWriter w = new IndexWriter(dir, newIndexWriterConfig());
w.addDocument(new Document());
DirectoryReader reader = DirectoryReader.open(w);
w.close();
IndexSearcher s = newSearcher(reader);
s.setQueryCache(null);
Weight weight = s.createWeight(s.rewrite(new DummyQuery()), randomFrom(ScoreMode.values()), 1f);
// exception when getting the scorer
expectThrows(UnsupportedOperationException.class, () -> weight.scorer(s.getIndexReader().leaves().get(0)));
// no exception, means scorerSupplier is delegated
weight.scorerSupplier(s.getIndexReader().leaves().get(0));
reader.close();
dir.close();
}
private static final QueryCachingPolicy ALWAYS_CACHE_POLICY = new QueryCachingPolicy() {
@Override
public void onUse(Query query) {}
@Override
public boolean shouldCache(Query query) throws IOException {
return true;
}
};
}

View File

@ -55,6 +55,10 @@ public class MinimumScoreCollector extends SimpleCollector {
        this.minimumScore = minimumScore;
    }

+    public Collector getCollector() {
+        return collector;
+    }
+
    @Override
    public void setScorer(Scorable scorer) throws IOException {
        if (!(scorer instanceof ScoreCachingWrappingScorer)) {

View File

@ -53,6 +53,10 @@ public class FilteredCollector implements Collector {
        this.filter = filter;
    }

+    public Collector getCollector() {
+        return collector;
+    }
+
    @Override
    public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
        final ScorerSupplier filterScorerSupplier = filter.scorerSupplier(context);

View File

@ -36,6 +36,7 @@ import org.apache.lucene.search.BooleanClause.Occur;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.Query;

@ -82,6 +83,7 @@ import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.profile.Profilers;
import org.opensearch.search.query.QueryPhaseExecutionException;
import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.slice.SliceBuilder;
import org.opensearch.search.sort.SortAndFormats;

@ -163,7 +165,7 @@ final class DefaultSearchContext extends SearchContext {
    private Profilers profilers;

    private final Map<String, SearchExtBuilder> searchExtBuilders = new HashMap<>();
-    private final Map<Class<?>, Collector> queryCollectors = new HashMap<>();
+    private final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
    private final QueryShardContext queryShardContext;
    private final FetchPhase fetchPhase;

@ -823,8 +825,8 @@ final class DefaultSearchContext extends SearchContext {
    }

    @Override
-    public Map<Class<?>, Collector> queryCollectors() {
-        return queryCollectors;
+    public Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers() {
+        return queryCollectorManagers;
    }

    @Override

View File

@ -32,6 +32,7 @@
package org.opensearch.search.aggregations;

import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.lucene.search.Queries;

@ -40,9 +41,11 @@ import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.query.CollectorResult;
import org.opensearch.search.profile.query.InternalProfileCollector;
import org.opensearch.search.query.QueryPhaseExecutionException;
+import org.opensearch.search.query.ReduceableSearchResult;

import java.io.IOException;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.List;

@ -68,17 +71,18 @@ public class AggregationPhase {
            }
            context.aggregations().aggregators(aggregators);
            if (!collectors.isEmpty()) {
-                Collector collector = MultiBucketCollector.wrap(collectors);
-                ((BucketCollector) collector).preCollection();
-                if (context.getProfilers() != null) {
-                    collector = new InternalProfileCollector(
-                        collector,
-                        CollectorResult.REASON_AGGREGATION,
-                        // TODO: report on child aggs as well
-                        Collections.emptyList()
-                    );
-                }
-                context.queryCollectors().put(AggregationPhase.class, collector);
+                final Collector collector = createCollector(context, collectors);
+                context.queryCollectorManagers().put(AggregationPhase.class, new CollectorManager<Collector, ReduceableSearchResult>() {
+                    @Override
+                    public Collector newCollector() throws IOException {
+                        return collector;
+                    }
+
+                    @Override
+                    public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
+                        throw new UnsupportedOperationException("The concurrent aggregation over index segments is not supported");
+                    }
+                });
            }
        } catch (IOException e) {
            throw new AggregationInitializationException("Could not initialize aggregators", e);

@ -147,6 +151,20 @@ public class AggregationPhase {
        // disable aggregations so that they don't run on next pages in case of scrolling
        context.aggregations(null);
-        context.queryCollectors().remove(AggregationPhase.class);
+        context.queryCollectorManagers().remove(AggregationPhase.class);
+    }
+
+    private Collector createCollector(SearchContext context, List<Aggregator> collectors) throws IOException {
+        Collector collector = MultiBucketCollector.wrap(collectors);
+        ((BucketCollector) collector).preCollection();
+        if (context.getProfilers() != null) {
+            collector = new InternalProfileCollector(
+                collector,
+                CollectorResult.REASON_AGGREGATION,
+                // TODO: report on child aggs as well
+                Collections.emptyList()
+            );
+        }
+        return collector;
    }
}

View File

@ -96,16 +96,6 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
    private QueryProfiler profiler;
    private MutableQueryTimeout cancellable;

-    public ContextIndexSearcher(
-        IndexReader reader,
-        Similarity similarity,
-        QueryCache queryCache,
-        QueryCachingPolicy queryCachingPolicy,
-        boolean wrapWithExitableDirectoryReader
-    ) throws IOException {
-        this(reader, similarity, queryCache, queryCachingPolicy, new MutableQueryTimeout(), wrapWithExitableDirectoryReader, null);
-    }
-
    public ContextIndexSearcher(
        IndexReader reader,
        Similarity similarity,

@ -233,6 +223,25 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
        result.topDocs(new TopDocsAndMaxScore(mergedTopDocs, Float.NaN), formats);
    }

+    public void search(
+        Query query,
+        CollectorManager<?, TopFieldDocs> manager,
+        QuerySearchResult result,
+        DocValueFormat[] formats,
+        TotalHits totalHits
+    ) throws IOException {
+        TopFieldDocs mergedTopDocs = search(query, manager);
+        // Lucene sets shards indexes during merging of topDocs from different collectors
+        // We need to reset shard index; OpenSearch will set shard index later during reduce stage
+        for (ScoreDoc scoreDoc : mergedTopDocs.scoreDocs) {
+            scoreDoc.shardIndex = -1;
+        }
+        if (totalHits != null) { // we have already precalculated totalHits for the whole index
+            mergedTopDocs = new TopFieldDocs(totalHits, mergedTopDocs.scoreDocs, mergedTopDocs.fields);
+        }
+        result.topDocs(new TopDocsAndMaxScore(mergedTopDocs, Float.NaN), formats);
+    }
+
    @Override
    protected void search(List<LeafReaderContext> leaves, Weight weight, Collector collector) throws IOException {
        for (LeafReaderContext ctx : leaves) { // search each subreader

@ -420,8 +429,4 @@ public class ContextIndexSearcher extends IndexSearcher implements Releasable {
            runnables.clear();
        }
    }
-
-    public boolean allowConcurrentSegmentSearch() {
-        return (getExecutor() != null);
-    }
}

View File

@ -33,6 +33,7 @@
package org.opensearch.search.internal;

import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.opensearch.action.search.SearchShardTask;

@ -61,6 +62,7 @@ import org.opensearch.search.fetch.subphase.ScriptFieldsContext;
import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext;
import org.opensearch.search.profile.Profilers;
import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;
import org.opensearch.search.suggest.SuggestionSearchContext;

@ -492,8 +494,8 @@ public abstract class FilteredSearchContext extends SearchContext {
    }

    @Override
-    public Map<Class<?>, Collector> queryCollectors() {
-        return in.queryCollectors();
+    public Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers() {
+        return in.queryCollectorManagers();
    }

    @Override

View File

@ -32,6 +32,7 @@
package org.opensearch.search.internal;

import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.opensearch.action.search.SearchShardTask;

@ -66,6 +67,7 @@ import org.opensearch.search.fetch.subphase.ScriptFieldsContext;
import org.opensearch.search.fetch.subphase.highlight.SearchHighlightContext;
import org.opensearch.search.profile.Profilers;
import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;
import org.opensearch.search.suggest.SuggestionSearchContext;

@ -388,8 +390,8 @@ public abstract class SearchContext implements Releasable {
     */
    public abstract long getRelativeTimeInMillis();

-    /** Return a view of the additional query collectors that should be run for this context. */
-    public abstract Map<Class<?>, Collector> queryCollectors();
+    /** Return a view of the additional query collector managers that should be run for this context. */
+    public abstract Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers();

    public abstract QueryShardContext getQueryShardContext();

View File

@ -57,7 +57,7 @@ public final class Profilers {
    /** Switch to a new profile. */
    public QueryProfiler addQueryProfiler() {
-        QueryProfiler profiler = new QueryProfiler(searcher.allowConcurrentSegmentSearch());
+        QueryProfiler profiler = new QueryProfiler(searcher.getExecutor() != null);
        searcher.setProfiler(profiler);
        queryProfilers.add(profiler);
        return profiler;

View File

@ -0,0 +1,89 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.profile.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.opensearch.search.query.EarlyTerminatingListener;
import org.opensearch.search.query.ReduceableSearchResult;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class InternalProfileCollectorManager
implements
ProfileCollectorManager<InternalProfileCollector, ReduceableSearchResult>,
EarlyTerminatingListener {
private final CollectorManager<? extends Collector, ReduceableSearchResult> manager;
private final String reason;
private final List<InternalProfileCollectorManager> children;
private long time = 0;
public InternalProfileCollectorManager(
CollectorManager<? extends Collector, ReduceableSearchResult> manager,
String reason,
List<InternalProfileCollectorManager> children
) {
this.manager = manager;
this.reason = reason;
this.children = children;
}
@Override
public InternalProfileCollector newCollector() throws IOException {
return new InternalProfileCollector(manager.newCollector(), reason, children);
}
@SuppressWarnings("unchecked")
@Override
public ReduceableSearchResult reduce(Collection<InternalProfileCollector> collectors) throws IOException {
final Collection<Collector> subs = new ArrayList<>();
for (final InternalProfileCollector collector : collectors) {
subs.add(collector.getCollector());
time += collector.getTime();
}
return ((CollectorManager<Collector, ReduceableSearchResult>) manager).reduce(subs);
}
@Override
public String getReason() {
return reason;
}
@Override
public long getTime() {
return time;
}
@Override
public Collection<? extends InternalProfileComponent> children() {
return children;
}
@Override
public String getName() {
return manager.getClass().getSimpleName();
}
@Override
public CollectorResult getCollectorTree() {
return InternalProfileCollector.doGetCollectorTree(this);
}
@Override
public void onEarlyTermination(int maxCountHits, boolean forcedTermination) {
if (manager instanceof EarlyTerminatingListener) {
((EarlyTerminatingListener) manager).onEarlyTermination(maxCountHits, forcedTermination);
}
}
}

View File

@ -0,0 +1,17 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.profile.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
/**
* Collector manager which supports profiling
*/
public interface ProfileCollectorManager<C extends Collector, T> extends CollectorManager<C, T>, InternalProfileComponent {}

View File

@ -95,6 +95,10 @@ public class EarlyTerminatingCollector extends FilterCollector {
        };
    }

+    Collector getCollector() {
+        return in;
+    }
+
    /**
     * Returns true if this collector has early terminated.
     */

View File

@ -0,0 +1,74 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public class EarlyTerminatingCollectorManager<C extends Collector>
implements
CollectorManager<EarlyTerminatingCollector, ReduceableSearchResult>,
EarlyTerminatingListener {
private final CollectorManager<C, ReduceableSearchResult> manager;
private final int maxCountHits;
private boolean forceTermination;
EarlyTerminatingCollectorManager(CollectorManager<C, ReduceableSearchResult> manager, int maxCountHits, boolean forceTermination) {
this.manager = manager;
this.maxCountHits = maxCountHits;
this.forceTermination = forceTermination;
}
@Override
public EarlyTerminatingCollector newCollector() throws IOException {
return new EarlyTerminatingCollector(manager.newCollector(), maxCountHits, false /* forced termination is not supported */);
}
@SuppressWarnings("unchecked")
@Override
public ReduceableSearchResult reduce(Collection<EarlyTerminatingCollector> collectors) throws IOException {
final List<C> innerCollectors = new ArrayList<>(collectors.size());
boolean didTerminateEarly = false;
for (EarlyTerminatingCollector collector : collectors) {
innerCollectors.add((C) collector.getCollector());
if (collector.hasEarlyTerminated()) {
didTerminateEarly = true;
}
}
if (didTerminateEarly) {
onEarlyTermination(maxCountHits, forceTermination);
final ReduceableSearchResult result = manager.reduce(innerCollectors);
return new ReduceableSearchResult() {
@Override
public void reduce(QuerySearchResult r) throws IOException {
result.reduce(r);
r.terminatedEarly(true);
}
};
}
return manager.reduce(innerCollectors);
}
@Override
public void onEarlyTermination(int maxCountHits, boolean forcedTermination) {
if (manager instanceof EarlyTerminatingListener) {
((EarlyTerminatingListener) manager).onEarlyTermination(maxCountHits, forcedTermination);
}
}
}
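
A self-contained analogue of the early-termination mechanism (simplified and hypothetical, not this change's EarlyTerminatingCollector): the per-slice collector throws Lucene's CollectionTerminatedException once the hit budget is spent, the searcher swallows it and stops collecting that segment, and reduce() can then inspect which collectors stopped early — the same signal the manager above turns into terminatedEarly(true):

import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.SimpleCollector;
import java.util.Collection;

class EarlyCountManager implements CollectorManager<EarlyCountManager.EarlyCountCollector, Integer> {
    private final int maxHits;

    EarlyCountManager(int maxHits) {
        this.maxHits = maxHits;
    }

    static class EarlyCountCollector extends SimpleCollector {
        private final int maxHits;
        int count;
        boolean terminatedEarly;

        EarlyCountCollector(int maxHits) {
            this.maxHits = maxHits;
        }

        @Override
        public void collect(int doc) {
            if (count >= maxHits) {
                terminatedEarly = true;
                // Caught by IndexSearcher: collection of the current segment stops,
                // the rest of the search continues.
                throw new CollectionTerminatedException();
            }
            ++count;
        }

        @Override
        public ScoreMode scoreMode() {
            return ScoreMode.COMPLETE_NO_SCORES;
        }
    }

    @Override
    public EarlyCountCollector newCollector() {
        return new EarlyCountCollector(maxHits);
    }

    @Override
    public Integer reduce(Collection<EarlyCountCollector> collectors) {
        int total = 0;
        for (EarlyCountCollector collector : collectors) {
            total += collector.count;
        }
        return total;
    }
}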

View File

@ -0,0 +1,22 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
/**
* Early termination event listener. It is used during concurrent segment search
* to propagate the early termination intent.
*/
public interface EarlyTerminatingListener {
/**
* Early termination event notification
* @param maxCountHits desired maximum number of hits
* @param forcedTermination :true" if forced termination has been requested, "false" otherwise
*/
void onEarlyTermination(int maxCountHits, boolean forcedTermination);
}

View File

@ -0,0 +1,45 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.search.FilteredCollector;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
class FilteredCollectorManager implements CollectorManager<FilteredCollector, ReduceableSearchResult> {
private final CollectorManager<? extends Collector, ReduceableSearchResult> manager;
private final Weight filter;
FilteredCollectorManager(CollectorManager<? extends Collector, ReduceableSearchResult> manager, Weight filter) {
this.manager = manager;
this.filter = filter;
}
@Override
public FilteredCollector newCollector() throws IOException {
return new FilteredCollector(manager.newCollector(), filter);
}
@Override
@SuppressWarnings("unchecked")
public ReduceableSearchResult reduce(Collection<FilteredCollector> collectors) throws IOException {
final Collection<Collector> subCollectors = new ArrayList<>();
for (final FilteredCollector collector : collectors) {
subCollectors.add(collector.getCollector());
}
return ((CollectorManager<Collector, ReduceableSearchResult>) manager).reduce(subCollectors);
}
}

View File

@ -0,0 +1,44 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.opensearch.common.lucene.MinimumScoreCollector;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
class MinimumCollectorManager implements CollectorManager<MinimumScoreCollector, ReduceableSearchResult> {
private final CollectorManager<? extends Collector, ReduceableSearchResult> manager;
private final float minimumScore;
MinimumCollectorManager(CollectorManager<? extends Collector, ReduceableSearchResult> manager, float minimumScore) {
this.manager = manager;
this.minimumScore = minimumScore;
}
@Override
public MinimumScoreCollector newCollector() throws IOException {
return new MinimumScoreCollector(manager.newCollector(), minimumScore);
}
@Override
@SuppressWarnings("unchecked")
public ReduceableSearchResult reduce(Collection<MinimumScoreCollector> collectors) throws IOException {
final Collection<Collector> subCollectors = new ArrayList<>();
for (final MinimumScoreCollector collector : collectors) {
subCollectors.add(collector.getCollector());
}
return ((CollectorManager<Collector, ReduceableSearchResult>) manager).reduce(subCollectors);
}
}

View File

@ -0,0 +1,58 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.ScoreMode;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
/**
 * Wraps MultiCollector and provides access to the underlying collectors.
* Please check out https://github.com/apache/lucene/pull/455.
*/
public class MultiCollectorWrapper implements Collector {
private final MultiCollector delegate;
private final Collection<Collector> collectors;
MultiCollectorWrapper(MultiCollector delegate, Collection<Collector> collectors) {
this.delegate = delegate;
this.collectors = collectors;
}
@Override
public LeafCollector getLeafCollector(LeafReaderContext context) throws IOException {
return delegate.getLeafCollector(context);
}
@Override
public ScoreMode scoreMode() {
return delegate.scoreMode();
}
public Collection<Collector> getCollectors() {
return collectors;
}
public static Collector wrap(Collector... collectors) {
final List<Collector> collectorsList = Arrays.asList(collectors);
final Collector collector = MultiCollector.wrap(collectorsList);
if (collector instanceof MultiCollector) {
return new MultiCollectorWrapper((MultiCollector) collector, collectorsList);
} else {
return collector;
}
}
}
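
A short usage sketch (illustrative only; the class name is hypothetical): MultiCollector.wrap hides its children, so the managers above need the wrapper's getCollectors() to route each per-slice sub-collector back to the manager that created it during reduce(). Note the single-collector fallback, which callers must handle:

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.TotalHitCountCollector;
import org.opensearch.search.query.MultiCollectorWrapper;

class MultiCollectorWrapperExample {
    public static void main(String[] args) {
        Collector single = MultiCollectorWrapper.wrap(new TotalHitCountCollector());
        // With one collector, wrap() returns it unchanged rather than wrapping it.
        System.out.println(single instanceof MultiCollectorWrapper); // false

        Collector multiple = MultiCollectorWrapper.wrap(new TotalHitCountCollector(), new TotalHitCountCollector());
        if (multiple instanceof MultiCollectorWrapper) {
            for (Collector sub : ((MultiCollectorWrapper) multiple).getCollectors()) {
                System.out.println(sub.getClass().getSimpleName());
            }
        }
    }
}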

View File

@ -33,6 +33,7 @@
package org.opensearch.search.query;

import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;

@ -42,6 +43,7 @@ import org.apache.lucene.search.Weight;
import org.opensearch.common.lucene.MinimumScoreCollector;
import org.opensearch.common.lucene.search.FilteredCollector;
import org.opensearch.search.profile.query.InternalProfileCollector;
+import org.opensearch.search.profile.query.InternalProfileCollectorManager;

import java.io.IOException;
import java.util.ArrayList;

@ -54,7 +56,7 @@ import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_
import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_POST_FILTER;
import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_TERMINATE_AFTER_COUNT;

-abstract class QueryCollectorContext {
+public abstract class QueryCollectorContext {
    private static final Collector EMPTY_COLLECTOR = new SimpleCollector() {
        @Override
        public void collect(int doc) {}

@ -77,6 +79,8 @@ abstract class QueryCollectorContext {
     */
    abstract Collector create(Collector in) throws IOException;

+    abstract CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException;
+
    /**
     * Wraps this collector with a profiler
     */

@ -85,6 +89,18 @@
        return new InternalProfileCollector(collector, profilerName, in != null ? Collections.singletonList(in) : Collections.emptyList());
    }

+    /**
+     * Wraps this collector manager with a profiler
+     */
+    protected InternalProfileCollectorManager createWithProfiler(InternalProfileCollectorManager in) throws IOException {
+        final CollectorManager<? extends Collector, ReduceableSearchResult> manager = createManager(in);
+        return new InternalProfileCollectorManager(
+            manager,
+            profilerName,
+            in != null ? Collections.singletonList(in) : Collections.emptyList()
+        );
+    }
+
    /**
     * Post-process <code>result</code> after search execution.
     *

@ -126,6 +142,11 @@
            Collector create(Collector in) {
                return new MinimumScoreCollector(in, minScore);
            }

+            @Override
+            CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
+                return new MinimumCollectorManager(in, minScore);
+            }
        };
    }

@ -139,35 +160,58 @@
                final Weight filterWeight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f);
                return new FilteredCollector(in, filterWeight);
            }

+            @Override
+            CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
+                final Weight filterWeight = searcher.createWeight(searcher.rewrite(query), ScoreMode.COMPLETE_NO_SCORES, 1f);
+                return new FilteredCollectorManager(in, filterWeight);
+            }
        };
    }

    /**
-     * Creates a multi collector from the provided <code>subs</code>
+     * Creates a multi collector manager from the provided <code>subs</code>
     */
-    static QueryCollectorContext createMultiCollectorContext(Collection<Collector> subs) {
+    static QueryCollectorContext createMultiCollectorContext(
+        Collection<CollectorManager<? extends Collector, ReduceableSearchResult>> subs
+    ) {
        return new QueryCollectorContext(REASON_SEARCH_MULTI) {
            @Override
-            Collector create(Collector in) {
+            Collector create(Collector in) throws IOException {
                List<Collector> subCollectors = new ArrayList<>();
                subCollectors.add(in);
-                subCollectors.addAll(subs);
+                for (CollectorManager<? extends Collector, ReduceableSearchResult> manager : subs) {
+                    subCollectors.add(manager.newCollector());
+                }
                return MultiCollector.wrap(subCollectors);
            }

            @Override
-            protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) {
+            protected InternalProfileCollector createWithProfiler(InternalProfileCollector in) throws IOException {
                final List<InternalProfileCollector> subCollectors = new ArrayList<>();
                subCollectors.add(in);
-                if (subs.stream().anyMatch((col) -> col instanceof InternalProfileCollector == false)) {
-                    throw new IllegalArgumentException("non-profiling collector");
-                }
-                for (Collector collector : subs) {
-                    subCollectors.add((InternalProfileCollector) collector);
-                }
+                for (CollectorManager<? extends Collector, ReduceableSearchResult> manager : subs) {
+                    final Collector collector = manager.newCollector();
+                    if (!(collector instanceof InternalProfileCollector)) {
+                        throw new IllegalArgumentException("non-profiling collector");
+                    }
+                    subCollectors.add((InternalProfileCollector) collector);
+                }
                final Collector collector = MultiCollector.wrap(subCollectors);
                return new InternalProfileCollector(collector, REASON_SEARCH_MULTI, subCollectors);
            }

+            @Override
+            CollectorManager<? extends Collector, ReduceableSearchResult> createManager(
+                CollectorManager<? extends Collector, ReduceableSearchResult> in
+            ) throws IOException {
+                final List<CollectorManager<?, ReduceableSearchResult>> managers = new ArrayList<>();
+                managers.add(in);
+                managers.addAll(subs);
+                return QueryCollectorManagerContext.createOpaqueCollectorManager(managers);
+            }
        };
    }

@ -192,6 +236,13 @@ abstract class QueryCollectorContext {
                this.collector = MultiCollector.wrap(subCollectors);
                return collector;
            }

+            @Override
+            CollectorManager<? extends Collector, ReduceableSearchResult> createManager(
+                CollectorManager<? extends Collector, ReduceableSearchResult> in
+            ) throws IOException {
+                return new EarlyTerminatingCollectorManager<>(in, numHits, true);
+            }
        };
    }
}

View File

@ -0,0 +1,99 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.MultiCollectorManager;
import org.opensearch.search.profile.query.InternalProfileCollectorManager;
import org.opensearch.search.profile.query.ProfileCollectorManager;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
public abstract class QueryCollectorManagerContext {
private static class QueryCollectorManager implements CollectorManager<Collector, ReduceableSearchResult> {
private final MultiCollectorManager manager;
private QueryCollectorManager(Collection<CollectorManager<? extends Collector, ReduceableSearchResult>> managers) {
this.manager = new MultiCollectorManager(managers.toArray(new CollectorManager<?, ?>[0]));
}
@Override
public Collector newCollector() throws IOException {
return manager.newCollector();
}
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
final Object[] results = manager.reduce(collectors);
final ReduceableSearchResult[] transformed = new ReduceableSearchResult[results.length];
for (int i = 0; i < results.length; ++i) {
assert results[i] instanceof ReduceableSearchResult;
transformed[i] = (ReduceableSearchResult) results[i];
}
return reduceWith(transformed);
}
protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] results) {
return (QuerySearchResult result) -> {
for (final ReduceableSearchResult r : results) {
r.reduce(result);
}
};
}
}
private static class OpaqueQueryCollectorManager extends QueryCollectorManager {
private OpaqueQueryCollectorManager(Collection<CollectorManager<? extends Collector, ReduceableSearchResult>> managers) {
super(managers);
}
@Override
protected ReduceableSearchResult reduceWith(final ReduceableSearchResult[] results) {
return (QuerySearchResult result) -> {};
}
}
public static CollectorManager<? extends Collector, ReduceableSearchResult> createOpaqueCollectorManager(
List<CollectorManager<? extends Collector, ReduceableSearchResult>> managers
) throws IOException {
return new OpaqueQueryCollectorManager(managers);
}
public static CollectorManager<? extends Collector, ReduceableSearchResult> createMultiCollectorManager(
List<QueryCollectorContext> collectors
) throws IOException {
final Collection<CollectorManager<? extends Collector, ReduceableSearchResult>> managers = new ArrayList<>();
CollectorManager<?, ReduceableSearchResult> manager = null;
for (QueryCollectorContext ctx : collectors) {
manager = ctx.createManager(manager);
managers.add(manager);
}
return new QueryCollectorManager(managers);
}
public static ProfileCollectorManager<? extends Collector, ReduceableSearchResult> createQueryCollectorManagerWithProfiler(
List<QueryCollectorContext> collectors
) throws IOException {
InternalProfileCollectorManager manager = null;
for (QueryCollectorContext ctx : collectors) {
manager = ctx.createWithProfiler(manager);
}
return manager;
}
}
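
Tying the pieces together: createMultiCollectorManager folds the contexts left-to-right, so each context's manager wraps the one before it, while the outer QueryCollectorManager also remembers every intermediate manager so all of their reduce results are applied to the QuerySearchResult. A hypothetical caller (class and method names are illustrative; the wiring mirrors ConcurrentQueryPhaseSearcher#searchWithCollectorManager above):

package org.opensearch.search.query;

import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.Query;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
import java.util.LinkedList;

class CollectorChainExample {
    // Builds the manager chain from the collector contexts and applies the
    // reduced result to the context's QuerySearchResult.
    static void searchWithChain(
        SearchContext searchContext,
        ContextIndexSearcher searcher,
        Query query,
        LinkedList<QueryCollectorContext> contexts
    ) throws IOException {
        final CollectorManager<? extends Collector, ReduceableSearchResult> manager =
            QueryCollectorManagerContext.createMultiCollectorManager(contexts);
        final ReduceableSearchResult result = searcher.search(query, manager);
        result.reduce(searchContext.queryResult());
    }
}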

View File

@ -238,9 +238,9 @@ public class QueryPhase {
            // this collector can filter documents during the collection
            hasFilterCollector = true;
        }
-        if (searchContext.queryCollectors().isEmpty() == false) {
+        if (searchContext.queryCollectorManagers().isEmpty() == false) {
            // plug in additional collectors, like aggregations
-            collectors.add(createMultiCollectorContext(searchContext.queryCollectors().values()));
+            collectors.add(createMultiCollectorContext(searchContext.queryCollectorManagers().values()));
        }
        if (searchContext.minimumScore() != null) {
            // apply the minimum score after multi collector so we filter aggs as well

View File

@ -0,0 +1,23 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import java.io.IOException;
/**
 * The search result callback returned by the reduce phase of the collector manager.
*/
public interface ReduceableSearchResult {
/**
* Apply the reduce operation to the query search results
* @param result query search results
* @throws IOException exception if reduce operation failed
*/
void reduce(QuerySearchResult result) throws IOException;
}
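
Because the interface declares a single abstract method, a reduce phase can return its merged state as a lambda that is only applied to a QuerySearchResult later, which is how the collector managers in this change use it. A minimal sketch; the querySearchResult variable is assumed to exist:

// Hedged sketch: a ReduceableSearchResult produced as a lambda and applied once.
ReduceableSearchResult reduced = (QuerySearchResult result) -> {
    // write the merged state here, e.g. result.topDocs(...) with combined top docs
};
reduced.reduce(querySearchResult); // invoked by the query phase after reduce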

View File

@@ -44,6 +44,7 @@ import org.apache.lucene.queries.spans.SpanQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.ConstantScoreQuery;
import org.apache.lucene.search.DocValuesFieldExistsQuery;
import org.apache.lucene.search.FieldDoc;
@@ -80,6 +81,9 @@ import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.function.Supplier;
@@ -89,7 +93,7 @@ import static org.opensearch.search.profile.query.CollectorResult.REASON_SEARCH_
/**
 * A {@link QueryCollectorContext} that creates top docs collector
 */
public abstract class TopDocsCollectorContext extends QueryCollectorContext {
protected final int numHits;
TopDocsCollectorContext(String profilerName, int numHits) {
@@ -107,7 +111,7 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
/**
 * Returns true if the top docs should be re-scored after initial search
 */
public boolean shouldRescore() {
return false;
}
@@ -115,6 +119,8 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
private final Sort sort;
private final Collector collector;
private final Supplier<TotalHits> hitCountSupplier;
private final int trackTotalHitsUpTo;
private final int hitCount;
/**
 * Ctr
@@ -132,16 +138,18 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
) throws IOException {
super(REASON_SEARCH_COUNT, 0);
this.sort = sortAndFormats == null ? null : sortAndFormats.sort;
this.trackTotalHitsUpTo = trackTotalHitsUpTo;
if (this.trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
this.collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 0, false);
// for bwc hit count is set to 0, it will be converted to -1 by the coordinating node
this.hitCountSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
this.hitCount = Integer.MIN_VALUE;
} else {
TotalHitCountCollector hitCountCollector = new TotalHitCountCollector();
// implicit total hit counts are valid only when there is no filter collector in the chain
this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query);
if (this.hitCount == -1) {
if (this.trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
this.collector = hitCountCollector;
this.hitCountSupplier = () -> new TotalHits(hitCountCollector.getTotalHits(), TotalHits.Relation.EQUAL_TO);
} else {
@@ -159,6 +167,39 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
}
}
@Override
CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
assert in == null;
CollectorManager<?, ReduceableSearchResult> manager = null;
if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
manager = new EarlyTerminatingCollectorManager<>(
new TotalHitCountCollectorManager.Empty(new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), sort),
0,
false
);
} else {
if (hitCount == -1) {
if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_ACCURATE) {
manager = new EarlyTerminatingCollectorManager<>(
new TotalHitCountCollectorManager(sort),
trackTotalHitsUpTo,
false
);
}
} else {
manager = new EarlyTerminatingCollectorManager<>(
new TotalHitCountCollectorManager.Empty(new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO), sort),
0,
false
);
}
}
return manager;
}
@Override
Collector create(Collector in) {
assert in == null;
@@ -181,7 +222,11 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
static class CollapsingTopDocsCollectorContext extends TopDocsCollectorContext {
private final DocValueFormat[] sortFmt;
private final CollapsingTopDocsCollector<?> topDocsCollector;
private final Collector collector;
private final Supplier<Float> maxScoreSupplier;
private final CollapseContext collapseContext;
private final boolean trackMaxScore;
private final Sort sort;
/**
 * Ctr
@@ -199,30 +244,94 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
super(REASON_SEARCH_TOP_HITS, numHits);
assert numHits > 0;
assert collapseContext != null;
this.sort = sortAndFormats == null ? Sort.RELEVANCE : sortAndFormats.sort;
this.sortFmt = sortAndFormats == null ? new DocValueFormat[] { DocValueFormat.RAW } : sortAndFormats.formats;
this.collapseContext = collapseContext;
this.topDocsCollector = collapseContext.createTopDocs(sort, numHits);
this.trackMaxScore = trackMaxScore;
MaxScoreCollector maxScoreCollector = null;
if (trackMaxScore) {
maxScoreCollector = new MaxScoreCollector();
maxScoreSupplier = maxScoreCollector::getMaxScore;
} else {
maxScoreSupplier = () -> Float.NaN;
}
this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector);
}
@Override
Collector create(Collector in) throws IOException {
assert in == null;
return collector;
}
@Override
void postProcess(QuerySearchResult result) throws IOException {
final CollapseTopFieldDocs topDocs = topDocsCollector.getTopDocs();
result.topDocs(new TopDocsAndMaxScore(topDocs, maxScoreSupplier.get()), sortFmt);
}
@Override
CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
return new CollectorManager<Collector, ReduceableSearchResult>() {
@Override
public Collector newCollector() throws IOException {
MaxScoreCollector maxScoreCollector = null;
if (trackMaxScore) {
maxScoreCollector = new MaxScoreCollector();
}
return MultiCollectorWrapper.wrap(collapseContext.createTopDocs(sort, numHits), maxScoreCollector);
}
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
final Collection<Collector> subs = new ArrayList<>();
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
subs.addAll(((MultiCollectorWrapper) collector).getCollectors());
} else {
subs.add(collector);
}
}
final Collection<CollapseTopFieldDocs> topFieldDocs = new ArrayList<CollapseTopFieldDocs>();
float maxScore = Float.NaN;
for (final Collector collector : subs) {
if (collector instanceof CollapsingTopDocsCollector<?>) {
topFieldDocs.add(((CollapsingTopDocsCollector<?>) collector).getTopDocs());
} else if (collector instanceof MaxScoreCollector) {
float score = ((MaxScoreCollector) collector).getMaxScore();
if (Float.isNaN(maxScore)) {
maxScore = score;
} else {
maxScore = Math.max(maxScore, score);
}
}
}
return reduceWith(topFieldDocs, maxScore);
}
};
}
protected ReduceableSearchResult reduceWith(final Collection<CollapseTopFieldDocs> topFieldDocs, float maxScore) {
return (QuerySearchResult result) -> {
final CollapseTopFieldDocs topDocs = CollapseTopFieldDocs.merge(
sort,
0,
numHits,
topFieldDocs.toArray(new CollapseTopFieldDocs[0]),
true
);
result.topDocs(new TopDocsAndMaxScore(topDocs, maxScore), sortFmt);
};
}
}
abstract static class SimpleTopDocsCollectorContext extends TopDocsCollectorContext {
@@ -240,11 +349,38 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
}
}
private static CollectorManager<? extends TopDocsCollector<?>, ? extends TopDocs> createCollectorManager(
@Nullable SortAndFormats sortAndFormats,
int numHits,
@Nullable ScoreDoc searchAfter,
int hitCountThreshold
) {
if (sortAndFormats == null) {
// Please see https://github.com/apache/lucene/pull/450, should be fixed in 9.x
if (searchAfter != null) {
return TopScoreDocCollector.createSharedManager(
numHits,
new FieldDoc(searchAfter.doc, searchAfter.score),
hitCountThreshold
);
} else {
return TopScoreDocCollector.createSharedManager(numHits, null, hitCountThreshold);
}
} else {
return TopFieldCollector.createSharedManager(sortAndFormats.sort, numHits, (FieldDoc) searchAfter, hitCountThreshold);
}
}
protected final @Nullable SortAndFormats sortAndFormats;
private final Collector collector;
private final Supplier<TotalHits> totalHitsSupplier;
private final Supplier<TopDocs> topDocsSupplier;
private final Supplier<Float> maxScoreSupplier;
private final ScoreDoc searchAfter;
private final int trackTotalHitsUpTo;
private final boolean trackMaxScore;
private final boolean hasInfMaxScore;
private final int hitCount;
/**
 * Ctr
@@ -269,24 +405,30 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
) throws IOException {
super(REASON_SEARCH_TOP_HITS, numHits);
this.sortAndFormats = sortAndFormats;
this.searchAfter = searchAfter;
this.trackTotalHitsUpTo = trackTotalHitsUpTo;
this.trackMaxScore = trackMaxScore;
this.hasInfMaxScore = hasInfMaxScore(query);
final TopDocsCollector<?> topDocsCollector;
if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) {
// disable max score optimization since we have a mandatory clause
// that doesn't track the maximum score
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> topDocsSupplier.get().totalHits;
hitCount = Integer.MIN_VALUE;
} else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
// don't compute hit counts via the collector
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
hitCount = -1;
} else {
// implicit total hit counts are valid only when there is no filter collector in the chain
this.hitCount = hasFilterCollector ? -1 : shortcutTotalHitCount(reader, query);
if (this.hitCount == -1) {
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> topDocsSupplier.get().totalHits;
@@ -294,7 +436,7 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
// don't compute hit counts via the collector
topDocsCollector = createCollector(sortAndFormats, numHits, searchAfter, 1);
topDocsSupplier = new CachedSupplier<>(topDocsCollector::topDocs);
totalHitsSupplier = () -> new TotalHits(this.hitCount, TotalHits.Relation.EQUAL_TO);
}
}
MaxScoreCollector maxScoreCollector = null;
@@ -315,7 +457,98 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
}
this.collector = MultiCollector.wrap(topDocsCollector, maxScoreCollector);
}
private class SimpleTopDocsCollectorManager
implements
CollectorManager<Collector, ReduceableSearchResult>,
EarlyTerminatingListener {
private Integer terminatedAfter;
private final CollectorManager<? extends TopDocsCollector<?>, ? extends TopDocs> manager;
private SimpleTopDocsCollectorManager() {
if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) {
// disable max score optimization since we have a mandatory clause
// that doesn't track the maximum score
manager = createCollectorManager(sortAndFormats, numHits, searchAfter, Integer.MAX_VALUE);
} else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
// don't compute hit counts via the collector
manager = createCollectorManager(sortAndFormats, numHits, searchAfter, 1);
} else {
// implicit total hit counts are valid only when there is no filter collector in the chain
if (hitCount == -1) {
manager = createCollectorManager(sortAndFormats, numHits, searchAfter, trackTotalHitsUpTo);
} else {
// don't compute hit counts via the collector
manager = createCollectorManager(sortAndFormats, numHits, searchAfter, 1);
}
}
}
@Override
public void onEarlyTermination(int maxCountHits, boolean forcedTermination) {
terminatedAfter = maxCountHits;
}
@Override
public Collector newCollector() throws IOException {
MaxScoreCollector maxScoreCollector = null;
if (sortAndFormats != null && trackMaxScore) {
maxScoreCollector = new MaxScoreCollector();
}
return MultiCollectorWrapper.wrap(manager.newCollector(), maxScoreCollector);
}
@SuppressWarnings("unchecked")
@Override
public ReduceableSearchResult reduce(Collection<Collector> collectors) throws IOException {
final Collection<TopDocsCollector<?>> topDocsCollectors = new ArrayList<>();
final Collection<MaxScoreCollector> maxScoreCollectors = new ArrayList<>();
for (final Collector collector : collectors) {
if (collector instanceof MultiCollectorWrapper) {
for (final Collector sub : (((MultiCollectorWrapper) collector).getCollectors())) {
if (sub instanceof TopDocsCollector<?>) {
topDocsCollectors.add((TopDocsCollector<?>) sub);
} else if (sub instanceof MaxScoreCollector) {
maxScoreCollectors.add((MaxScoreCollector) sub);
}
}
} else if (collector instanceof TopDocsCollector<?>) {
topDocsCollectors.add((TopDocsCollector<?>) collector);
} else if (collector instanceof MaxScoreCollector) {
maxScoreCollectors.add((MaxScoreCollector) collector);
}
}
float maxScore = Float.NaN;
for (final MaxScoreCollector collector : maxScoreCollectors) {
float score = collector.getMaxScore();
if (Float.isNaN(maxScore)) {
maxScore = score;
} else {
maxScore = Math.max(maxScore, score);
}
}
final TopDocs topDocs = ((CollectorManager<TopDocsCollector<?>, ? extends TopDocs>) manager).reduce(topDocsCollectors);
return reduceWith(topDocs, maxScore, terminatedAfter);
}
}
@Override
CollectorManager<?, ReduceableSearchResult> createManager(CollectorManager<?, ReduceableSearchResult> in) throws IOException {
assert in == null;
return new SimpleTopDocsCollectorManager();
}
protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
return (QuerySearchResult result) -> {
final TopDocsAndMaxScore topDocsAndMaxScore = newTopDocs(topDocs, maxScore, terminatedAfter);
result.topDocs(topDocsAndMaxScore, sortAndFormats == null ? null : sortAndFormats.formats);
};
}
@Override
@@ -324,6 +557,50 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
return collector;
}
TopDocsAndMaxScore newTopDocs(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
TotalHits totalHits = null;
if ((sortAndFormats == null || SortField.FIELD_SCORE.equals(sortAndFormats.sort.getSort()[0])) && hasInfMaxScore) {
totalHits = topDocs.totalHits;
} else if (trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED) {
// don't compute hit counts via the collector
totalHits = new TotalHits(0, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
} else {
if (hitCount == -1) {
totalHits = topDocs.totalHits;
} else {
totalHits = new TotalHits(hitCount, TotalHits.Relation.EQUAL_TO);
}
}
// Since we cannot support early forced termination, we have to simulate it by
// artificially reducing the number of total hits and doc scores.
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
if (terminatedAfter != null) {
if (totalHits.value > terminatedAfter) {
totalHits = new TotalHits(terminatedAfter, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
}
if (scoreDocs != null && scoreDocs.length > terminatedAfter) {
scoreDocs = Arrays.copyOf(scoreDocs, terminatedAfter);
}
}
final TopDocs newTopDocs;
if (topDocs instanceof TopFieldDocs) {
TopFieldDocs fieldDocs = (TopFieldDocs) topDocs;
newTopDocs = new TopFieldDocs(totalHits, scoreDocs, fieldDocs.fields);
} else {
newTopDocs = new TopDocs(totalHits, scoreDocs);
}
if (Float.isNaN(maxScore) && newTopDocs.scoreDocs.length > 0 && sortAndFormats == null) {
return new TopDocsAndMaxScore(newTopDocs, newTopDocs.scoreDocs[0].score);
} else {
return new TopDocsAndMaxScore(newTopDocs, maxScore);
}
}
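// Illustrative walk-through of the truncation above (numbers are hypothetical):
// if reduce() produced totalHits = TotalHits(42, EQUAL_TO) but a collector
// signalled terminatedAfter = 5, then totalHits becomes
// TotalHits(5, GREATER_THAN_OR_EQUAL_TO) and scoreDocs is copied down to its
// first 5 entries, so a concurrently collected result mimics an early-terminated one.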
TopDocsAndMaxScore newTopDocs() {
TopDocs in = topDocsSupplier.get();
float maxScore = maxScoreSupplier.get();
@@ -373,6 +650,35 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
this.numberOfShards = numberOfShards;
}
@Override
protected ReduceableSearchResult reduceWith(final TopDocs topDocs, final float maxScore, final Integer terminatedAfter) {
return (QuerySearchResult result) -> {
final TopDocsAndMaxScore topDocsAndMaxScore = newTopDocs(topDocs, maxScore, terminatedAfter);
if (scrollContext.totalHits == null) {
// first round
scrollContext.totalHits = topDocsAndMaxScore.topDocs.totalHits;
scrollContext.maxScore = topDocsAndMaxScore.maxScore;
} else {
// subsequent round: the total number of hits and
// the maximum score were computed on the first round
topDocsAndMaxScore.topDocs.totalHits = scrollContext.totalHits;
topDocsAndMaxScore.maxScore = scrollContext.maxScore;
}
if (numberOfShards == 1) {
// if we fetch the document in the same roundtrip, we already know the last emitted doc
if (topDocsAndMaxScore.topDocs.scoreDocs.length > 0) {
// set the last emitted doc
scrollContext.lastEmittedDoc = topDocsAndMaxScore.topDocs.scoreDocs[topDocsAndMaxScore.topDocs.scoreDocs.length
- 1];
}
}
result.topDocs(topDocsAndMaxScore, sortAndFormats == null ? null : sortAndFormats.formats);
};
}
@Override
void postProcess(QuerySearchResult result) throws IOException {
final TopDocsAndMaxScore topDocs = newTopDocs();
@@ -457,7 +763,7 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
/**
 * Creates a {@link TopDocsCollectorContext} from the provided <code>searchContext</code>.
 * @param hasFilterCollector True if the collector chain contains at least one collector that can filter documents.
 */
public static TopDocsCollectorContext createTopDocsCollectorContext(SearchContext searchContext, boolean hasFilterCollector)
throws IOException {
final IndexReader reader = searchContext.searcher().getIndexReader();
final Query query = searchContext.query();
@@ -515,7 +821,7 @@ abstract class TopDocsCollectorContext extends QueryCollectorContext {
hasFilterCollector
) {
@Override
public boolean shouldRescore() {
return rescore;
}
};

View File

@@ -0,0 +1,106 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.search.query;
import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import java.io.IOException;
import java.util.Collection;
public class TotalHitCountCollectorManager
implements
CollectorManager<TotalHitCountCollector, ReduceableSearchResult>,
EarlyTerminatingListener {
private static final TotalHitCountCollector EMPTY_COLLECTOR = new TotalHitCountCollector() {
@Override
public void collect(int doc) {}
@Override
public ScoreMode scoreMode() {
return ScoreMode.COMPLETE_NO_SCORES;
}
};
private final Sort sort;
private Integer terminatedAfter;
public TotalHitCountCollectorManager(final Sort sort) {
this.sort = sort;
}
@Override
public void onEarlyTermination(int maxCountHits, boolean forcedTermination) {
terminatedAfter = maxCountHits;
}
@Override
public TotalHitCountCollector newCollector() throws IOException {
return new TotalHitCountCollector();
}
@Override
public ReduceableSearchResult reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
return (QuerySearchResult result) -> {
final TotalHits.Relation relation = (terminatedAfter != null)
? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
: TotalHits.Relation.EQUAL_TO;
int totalHits = collectors.stream().mapToInt(TotalHitCountCollector::getTotalHits).sum();
if (terminatedAfter != null && totalHits > terminatedAfter) {
totalHits = terminatedAfter;
}
final TotalHits totalHitCount = new TotalHits(totalHits, relation);
final TopDocs topDocs = (sort != null)
? new TopFieldDocs(totalHitCount, Lucene.EMPTY_SCORE_DOCS, sort.getSort())
: new TopDocs(totalHitCount, Lucene.EMPTY_SCORE_DOCS);
result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), null);
};
}
static class Empty implements CollectorManager<TotalHitCountCollector, ReduceableSearchResult> {
private final TotalHits totalHits;
private final Sort sort;
Empty(final TotalHits totalHits, final Sort sort) {
this.totalHits = totalHits;
this.sort = sort;
}
@Override
public TotalHitCountCollector newCollector() throws IOException {
return EMPTY_COLLECTOR;
}
@Override
public ReduceableSearchResult reduce(Collection<TotalHitCountCollector> collectors) throws IOException {
return (QuerySearchResult result) -> {
final TopDocs topDocs;
if (sort != null) {
topDocs = new TopFieldDocs(totalHits, Lucene.EMPTY_SCORE_DOCS, sort.getSort());
} else {
topDocs = new TopDocs(totalHits, Lucene.EMPTY_SCORE_DOCS);
}
result.topDocs(new TopDocsAndMaxScore(topDocs, Float.NaN), null);
};
}
}
}
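
The manager above creates one TotalHitCountCollector per slice, sums their counts during reduce, and clamps the total when onEarlyTermination fired. A hedged usage sketch; the searcher, query, and querySearchResult variables are assumed to exist, and IndexSearcher.search(Query, CollectorManager) is the stock Lucene entry point:

// Hedged sketch: per-slice hit counting, summed in reduce().
final TotalHitCountCollectorManager manager = new TotalHitCountCollectorManager(null /* unsorted */);
final ReduceableSearchResult reduced = searcher.search(query, manager);
reduced.reduce(querySearchResult); // records TopDocs carrying the summed total hit count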

View File

@@ -32,6 +32,8 @@
package org.opensearch.search;
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.tests.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
@@ -76,7 +78,12 @@ import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.function.Supplier;
@@ -91,6 +98,25 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class DefaultSearchContextTests extends OpenSearchTestCase {
private final ExecutorService executor;
@ParametersFactory
public static Collection<Object[]> concurrency() {
return Arrays.asList(new Integer[] { 0 }, new Integer[] { 5 });
}
public DefaultSearchContextTests(int concurrency) {
this.executor = (concurrency > 0) ? Executors.newFixedThreadPool(concurrency) : null;
}
@Override
public void tearDown() throws Exception {
super.tearDown();
if (executor != null) {
ThreadPool.terminate(executor, 10, TimeUnit.SECONDS);
}
}
public void testPreProcess() throws Exception {
TimeValue timeout = new TimeValue(randomIntBetween(1, 100));
@@ -183,7 +209,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
contextWithoutScroll.from(300);
contextWithoutScroll.close();
@@ -225,7 +251,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
context1.from(300);
exception = expectThrows(IllegalArgumentException.class, () -> context1.preProcess(false));
@@ -295,7 +321,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
SliceBuilder sliceBuilder = mock(SliceBuilder.class);
@@ -334,7 +360,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
ParsedQuery parsedQuery = ParsedQuery.parsedMatchAllQuery();
context3.sliceBuilder(null).parsedQuery(parsedQuery).preProcess(false);
@@ -365,7 +391,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
context4.sliceBuilder(new SliceBuilder(1, 2)).parsedQuery(parsedQuery).preProcess(false);
Query query1 = context4.query();
@@ -446,7 +472,7 @@ public class DefaultSearchContextTests extends OpenSearchTestCase {
false,
Version.CURRENT,
false,
executor
);
assertThat(context.searcher().hasCancellations(), is(false));
context.searcher().addQueryCancellation(() -> {});

View File

@@ -108,7 +108,8 @@ public class SearchCancellationTests extends OpenSearchTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
);
NullPointerException npe = expectThrows(NullPointerException.class, () -> searcher.addQueryCancellation(null));
assertEquals("cancellation runnable should not be null", npe.getMessage());
@@ -127,7 +128,8 @@ public class SearchCancellationTests extends OpenSearchTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
);
searcher.search(new MatchAllDocsQuery(), collector1);
@@ -154,7 +156,8 @@ public class SearchCancellationTests extends OpenSearchTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
);
searcher.addQueryCancellation(cancellation);
CompiledAutomaton automaton = new CompiledAutomaton(new RegExp("a.*").toAutomaton());

View File

@@ -258,7 +258,8 @@ public class ContextIndexSearcherTests extends OpenSearchTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
true,
null
);
for (LeafReaderContext context : searcher.getIndexReader().leaves()) {

View File

@@ -32,8 +32,6 @@
package org.opensearch.search.profile.query;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field.Store;
import org.apache.lucene.document.StringField;
@@ -64,18 +62,12 @@ import org.opensearch.core.internal.io.IOUtils;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
@@ -85,16 +77,6 @@ public class QueryProfilerTests extends OpenSearchTestCase {
private Directory dir;
private IndexReader reader;
private ContextIndexSearcher searcher;
@Before
public void setUp() throws Exception {
@@ -120,7 +102,7 @@ public class QueryProfilerTests extends OpenSearchTestCase {
IndexSearcher.getDefaultQueryCache(),
ALWAYS_CACHE_POLICY,
true,
null
);
}
@@ -134,10 +116,6 @@ public class QueryProfilerTests extends OpenSearchTestCase {
assertThat(cache.getTotalCount(), equalTo(cache.getMissCount()));
assertThat(cache.getCacheSize(), equalTo(0L));
IOUtils.close(reader, dir);
dir = null;
reader = null;
@@ -145,7 +123,7 @@ public class QueryProfilerTests extends OpenSearchTestCase {
}
public void testBasic() throws IOException {
QueryProfiler profiler = new QueryProfiler(false);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.search(query, 1);
@@ -171,7 +149,7 @@ public class QueryProfilerTests extends OpenSearchTestCase {
}
public void testNoScoring() throws IOException {
QueryProfiler profiler = new QueryProfiler(false);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.search(query, 1, Sort.INDEXORDER); // scores are not needed
@@ -197,7 +175,7 @@ public class QueryProfilerTests extends OpenSearchTestCase {
}
public void testUseIndexStats() throws IOException {
QueryProfiler profiler = new QueryProfiler(false);
searcher.setProfiler(profiler);
Query query = new TermQuery(new Term("foo", "bar"));
searcher.count(query); // will use index stats
@@ -211,7 +189,7 @@ public class QueryProfilerTests extends OpenSearchTestCase {
}
public void testApproximations() throws IOException {
QueryProfiler profiler = new QueryProfiler(false);
searcher.setProfiler(profiler);
Query query = new RandomApproximationQuery(new TermQuery(new Term("foo", "bar")), random());
searcher.count(query);

View File

@@ -39,6 +39,7 @@ import org.apache.lucene.document.LatLonDocValuesField;
import org.apache.lucene.document.LatLonPoint;
import org.apache.lucene.document.LongPoint;
import org.apache.lucene.document.NumericDocValuesField;
import org.apache.lucene.document.SortedDocValuesField;
import org.apache.lucene.document.SortedSetDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.document.TextField;
@@ -77,6 +78,7 @@ import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHitCountCollector;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ScoreMode;
import org.apache.lucene.store.Directory;
@@ -88,12 +90,15 @@ import org.opensearch.index.mapper.DateFieldMapper;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType;
import org.opensearch.index.mapper.NumberFieldMapper.NumberType;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.search.OpenSearchToParentBlockJoinQuery;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardTestCase;
import org.opensearch.search.DocValueFormat;
import org.opensearch.search.collapse.CollapseBuilder;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.ScrollContext;
import org.opensearch.search.internal.SearchContext;
@@ -144,7 +149,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.parsedQuery(new ParsedQuery(query));
context.setSize(0);
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
final boolean rescore = QueryPhase.executeInternal(context.withCleanQueryResult());
assertFalse(rescore);
ContextIndexSearcher countSearcher = shouldCollectCount
Directory dir = newDirectory(); Directory dir = newDirectory();
IndexWriterConfig iwc = newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE); IndexWriterConfig iwc = newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE);
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc); RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
final int numDocs = scaledRandomIntBetween(100, 200); final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) { for (int i = 0; i < numDocs; ++i) {
Document doc = new Document(); Document doc = new Document();
if (randomBoolean()) { if (randomBoolean()) {
@@ -228,12 +233,12 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value);
context.setSearcher(newContextSearcher(reader));
context.parsedPostFilter(new ParsedQuery(new MatchNoDocsQuery()));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value);
reader.close();
dir.close();
@@ -261,7 +266,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setSize(10);
for (int i = 0; i < 10; i++) {
context.parsedPostFilter(new ParsedQuery(new TermQuery(new Term("foo", Integer.toString(i)))));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value);
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
}
@@ -283,12 +288,13 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
context.setSize(0);
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(1, context.queryResult().topDocs().topDocs.totalHits.value);
context.minimumScore(100);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(0, context.queryResult().topDocs().topDocs.totalHits.value);
assertEquals(TotalHits.Relation.EQUAL_TO, context.queryResult().topDocs().topDocs.totalHits.relation);
reader.close();
dir.close();
}
@@ -297,7 +303,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
Directory dir = newDirectory();
IndexWriterConfig iwc = newIndexWriterConfig();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) {
w.addDocument(new Document());
}
@@ -307,7 +313,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
QueryPhase.executeInternal(context.withCleanQueryResult());
QuerySearchResult results = context.queryResult();
assertThat(results.serviceTimeEWMA(), greaterThanOrEqualTo(0L));
assertThat(results.nodeQueueSize(), greaterThanOrEqualTo(0));
@@ -320,7 +326,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
final Sort sort = new Sort(new SortField("rank", SortField.Type.INT));
IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort);
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) {
w.addDocument(new Document());
}
@@ -336,14 +342,14 @@ public class QueryPhaseTests extends IndexShardTestCase {
int size = randomIntBetween(2, 5);
context.setSize(size);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertNull(context.queryResult().terminatedEarly());
assertThat(context.terminateAfter(), equalTo(0));
assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs));
context.setSearcher(newEarlyTerminationContextSearcher(reader, size));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertThat(context.terminateAfter(), equalTo(size));
assertThat(context.queryResult().getTotalHits().value, equalTo((long) numDocs));
@@ -356,7 +362,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
Directory dir = newDirectory();
IndexWriterConfig iwc = newIndexWriterConfig();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
if (randomBoolean()) {
@@ -377,25 +383,25 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.terminateAfter(numDocs);
{
context.setSize(10);
final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create();
context.queryCollectorManagers().put(TotalHitCountCollector.class, manager);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertFalse(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(10));
assertThat(manager.getTotalHits(), equalTo(numDocs));
}
context.terminateAfter(1);
{
context.setSize(1);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
context.setSize(0);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
@@ -403,7 +409,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
{
context.setSize(1);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
@@ -414,38 +420,38 @@ public class QueryPhaseTests extends IndexShardTestCase {
.add(new TermQuery(new Term("foo", "baz")), Occur.SHOULD)
.build();
context.parsedQuery(new ParsedQuery(bq));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
context.setSize(0);
context.parsedQuery(new ParsedQuery(bq));
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
}
{
context.setSize(1);
final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create();
context.queryCollectorManagers().put(TotalHitCountCollector.class, manager);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
assertThat(manager.getTotalHits(), equalTo(1));
context.queryCollectorManagers().clear();
}
{
context.setSize(0);
final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create();
context.queryCollectorManagers().put(TotalHitCountCollector.class, manager);
QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
assertThat(manager.getTotalHits(), equalTo(1));
}
// tests with trackTotalHits and terminateAfter // tests with trackTotalHits and terminateAfter
@@ -453,9 +459,9 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setSize(0);
for (int trackTotalHits : new int[] { -1, 3, 76, 100 }) {
context.trackTotalHitsUpTo(trackTotalHits);
-TotalHitCountCollector collector = new TotalHitCountCollector();
-context.queryCollectors().put(TotalHitCountCollector.class, collector);
-QueryPhase.executeInternal(context);
+final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create();
+context.queryCollectorManagers().put(TotalHitCountCollector.class, manager);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
if (trackTotalHits == -1) {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L));
@@ -463,16 +469,14 @@ public class QueryPhaseTests extends IndexShardTestCase {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) Math.min(trackTotalHits, 10)));
}
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(0));
-assertThat(collector.getTotalHits(), equalTo(10));
+assertThat(manager.getTotalHits(), equalTo(10));
}
context.terminateAfter(7);
context.setSize(10);
for (int trackTotalHits : new int[] { -1, 3, 75, 100 }) {
context.trackTotalHitsUpTo(trackTotalHits);
-EarlyTerminatingCollector collector = new EarlyTerminatingCollector(new TotalHitCountCollector(), 1, false);
-context.queryCollectors().put(EarlyTerminatingCollector.class, collector);
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertTrue(context.queryResult().terminatedEarly());
if (trackTotalHits == -1) {
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(0L));
@@ -490,7 +494,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
final Sort sort = new Sort(new SortField("rank", SortField.Type.INT));
IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort);
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
-final int numDocs = scaledRandomIntBetween(100, 200);
+final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
if (randomBoolean()) {
@@ -511,7 +515,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW }));
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class));
@@ -520,7 +524,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
{
context.parsedPostFilter(new ParsedQuery(new MinDocQuery(1)));
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertNull(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo(numDocs - 1L));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
@@ -528,28 +532,28 @@ public class QueryPhaseTests extends IndexShardTestCase {
assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2)));
context.parsedPostFilter(null);
-final TotalHitCountCollector totalHitCountCollector = new TotalHitCountCollector();
-context.queryCollectors().put(TotalHitCountCollector.class, totalHitCountCollector);
-QueryPhase.executeInternal(context);
+final TestTotalHitCountCollectorManager manager = TestTotalHitCountCollectorManager.create(sort);
+context.queryCollectorManagers().put(TotalHitCountCollector.class, manager);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertNull(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class));
assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2)));
-assertThat(totalHitCountCollector.getTotalHits(), equalTo(numDocs));
-context.queryCollectors().clear();
+assertThat(manager.getTotalHits(), equalTo(numDocs));
+context.queryCollectorManagers().clear();
}
{
context.setSearcher(newEarlyTerminationContextSearcher(reader, 1));
context.trackTotalHitsUpTo(SearchContext.TRACK_TOTAL_HITS_DISABLED);
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertNull(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class));
assertThat(fieldDoc.fields[0], anyOf(equalTo(1), equalTo(2)));
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertNull(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(1));
assertThat(context.queryResult().topDocs().topDocs.scoreDocs[0], instanceOf(FieldDoc.class));
@@ -564,7 +568,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
final Sort indexSort = new Sort(new SortField("rank", SortField.Type.INT), new SortField("tiebreaker", SortField.Type.INT));
IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(indexSort);
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
-final int numDocs = scaledRandomIntBetween(100, 200);
+final int numDocs = scaledRandomIntBetween(600, 900);
for (int i = 0; i < numDocs; ++i) {
Document doc = new Document();
doc.add(new NumericDocValuesField("rank", random().nextInt()));
@@ -592,7 +596,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setSize(10);
context.sort(searchSortAndFormat);
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertNull(context.queryResult().terminatedEarly());
assertThat(context.terminateAfter(), equalTo(0));
@@ -601,7 +605,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
FieldDoc lastDoc = (FieldDoc) context.queryResult().topDocs().topDocs.scoreDocs[sizeMinus1];
context.setSearcher(newEarlyTerminationContextSearcher(reader, 10));
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertNull(context.queryResult().terminatedEarly());
assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
assertThat(context.terminateAfter(), equalTo(0));
@@ -630,7 +634,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer());
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
Document doc = new Document();
-for (int i = 0; i < 10; i++) {
+final int numDocs = 2 * scaledRandomIntBetween(50, 450);
+for (int i = 0; i < numDocs; i++) {
doc.clear();
if (i % 2 == 0) {
doc.add(new TextField("title", "foo bar", Store.NO));
@@ -653,16 +658,16 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.trackTotalHitsUpTo(3);
TopDocsCollectorContext topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false);
assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.COMPLETE);
-QueryPhase.executeInternal(context);
-assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value);
assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.EQUAL_TO);
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3));
context.sort(new SortAndFormats(new Sort(new SortField("other", SortField.Type.INT)), new DocValueFormat[] { DocValueFormat.RAW }));
topDocsContext = TopDocsCollectorContext.createTopDocsCollectorContext(context, false);
assertEquals(topDocsContext.create(null).scoreMode(), org.apache.lucene.search.ScoreMode.TOP_DOCS);
-QueryPhase.executeInternal(context);
-assertEquals(5, context.queryResult().topDocs().topDocs.totalHits.value);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertEquals(numDocs / 2, context.queryResult().topDocs().topDocs.totalHits.value);
assertThat(context.queryResult().topDocs().topDocs.scoreDocs.length, equalTo(3));
assertEquals(context.queryResult().topDocs().topDocs.totalHits.relation, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO);
@@ -724,7 +729,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
searchContext.parsedQuery(query);
searchContext.setTask(task);
searchContext.setSize(10);
-QueryPhase.executeInternal(searchContext);
+QueryPhase.executeInternal(searchContext.withCleanQueryResult());
assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false);
}
@@ -736,7 +741,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
searchContext.parsedQuery(query);
searchContext.setTask(task);
searchContext.setSize(10);
-QueryPhase.executeInternal(searchContext);
+QueryPhase.executeInternal(searchContext.withCleanQueryResult());
assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, true);
}
@@ -748,7 +753,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
searchContext.parsedQuery(query);
searchContext.setTask(task);
searchContext.setSize(10);
-QueryPhase.executeInternal(searchContext);
+QueryPhase.executeInternal(searchContext.withCleanQueryResult());
assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false);
}
@@ -773,7 +778,7 @@ public class QueryPhaseTests extends IndexShardTestCase {
searchContext.setTask(task);
searchContext.from(5);
searchContext.setSize(0);
-QueryPhase.executeInternal(searchContext);
+QueryPhase.executeInternal(searchContext.withCleanQueryResult());
assertSortResults(searchContext.queryResult().topDocs().topDocs, (long) numDocs, false);
}
@@ -800,11 +805,15 @@ public class QueryPhaseTests extends IndexShardTestCase {
searchContext.parsedQuery(query);
searchContext.setTask(task);
searchContext.setSize(10);
-QueryPhase.executeInternal(searchContext);
+QueryPhase.executeInternal(searchContext.withCleanQueryResult());
final TopDocs topDocs = searchContext.queryResult().topDocs().topDocs;
long topValue = (long) ((FieldDoc) topDocs.scoreDocs[0]).fields[0];
assertThat(topValue, greaterThan(afterValue));
assertSortResults(topDocs, (long) numDocs, false);
+final TotalHits totalHits = topDocs.totalHits;
+assertEquals(TotalHits.Relation.EQUAL_TO, totalHits.relation);
+assertEquals(numDocs, totalHits.value);
}
reader.close();
@@ -916,13 +925,133 @@ public class QueryPhaseTests extends IndexShardTestCase {
context.setSize(1);
context.trackTotalHitsUpTo(5);
-QueryPhase.executeInternal(context);
+QueryPhase.executeInternal(context.withCleanQueryResult());
assertEquals(10, context.queryResult().topDocs().topDocs.totalHits.value);
reader.close();
dir.close();
}
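+// Verifies max score reporting: a real (non-NaN) max score is expected while
+// trackScores(true) is set, with and without an index sort, and NaN once score
+// tracking is disabled.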
+public void testMaxScore() throws Exception {
+Directory dir = newDirectory();
+final Sort sort = new Sort(new SortField("filter", SortField.Type.STRING));
+IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort);
+RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+final int numDocs = scaledRandomIntBetween(600, 900);
+for (int i = 0; i < numDocs; i++) {
+Document doc = new Document();
+doc.add(new StringField("foo", "bar", Store.NO));
+doc.add(new StringField("filter", "f1" + ((i > 0) ? " " + Integer.toString(i) : ""), Store.NO));
+doc.add(new SortedDocValuesField("filter", newBytesRef("f1" + ((i > 0) ? " " + Integer.toString(i) : ""))));
+w.addDocument(doc);
+}
+w.close();
+IndexReader reader = DirectoryReader.open(dir);
+TestSearchContext context = new TestSearchContext(null, indexShard, newContextSearcher(reader));
+context.trackScores(true);
+context.parsedQuery(
+new ParsedQuery(
+new BooleanQuery.Builder().add(new TermQuery(new Term("foo", "bar")), Occur.MUST)
+.add(new TermQuery(new Term("filter", "f1")), Occur.SHOULD)
+.build()
+)
+);
+context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
+context.setSize(1);
+context.trackTotalHitsUpTo(5);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertFalse(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L));
+context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW }));
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertFalse(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L));
+context.trackScores(false);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertTrue(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(1, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, greaterThanOrEqualTo(6L));
+reader.close();
+dir.close();
+}
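+// Verifies field collapsing: documents alternate between two "user" values, so
+// collapsed top docs should always contain exactly two groups regardless of sort
+// order or score tracking.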
+public void testCollapseQuerySearchResults() throws Exception {
+Directory dir = newDirectory();
+final Sort sort = new Sort(new SortField("user", SortField.Type.INT));
+IndexWriterConfig iwc = newIndexWriterConfig().setIndexSort(sort);
+RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
+// Always end up with uneven buckets so collapsing is predictable
+final int numDocs = 2 * scaledRandomIntBetween(600, 900) - 1;
+for (int i = 0; i < numDocs; i++) {
+Document doc = new Document();
+doc.add(new StringField("foo", "bar", Store.NO));
+doc.add(new NumericDocValuesField("user", i & 1));
+w.addDocument(doc);
+}
+w.close();
+IndexReader reader = DirectoryReader.open(dir);
+QueryShardContext queryShardContext = mock(QueryShardContext.class);
+when(queryShardContext.fieldMapper("user")).thenReturn(
+new NumberFieldType("user", NumberType.INTEGER, true, false, true, false, null, Collections.emptyMap())
+);
+TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader));
+context.collapse(new CollapseBuilder("user").build(context.getQueryShardContext()));
+context.trackScores(true);
+context.parsedQuery(new ParsedQuery(new TermQuery(new Term("foo", "bar"))));
+context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
+context.setSize(2);
+context.trackTotalHitsUpTo(5);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertFalse(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
+assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class));
+CollapseTopFieldDocs topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs;
+assertThat(topDocs.collapseValues.length, equalTo(2));
+assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0
+assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1
+context.sort(new SortAndFormats(sort, new DocValueFormat[] { DocValueFormat.RAW }));
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertFalse(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
+assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class));
+topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs;
+assertThat(topDocs.collapseValues.length, equalTo(2));
+assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0
+assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1
+context.trackScores(false);
+QueryPhase.executeInternal(context.withCleanQueryResult());
+assertTrue(Float.isNaN(context.queryResult().getMaxScore()));
+assertEquals(2, context.queryResult().topDocs().topDocs.scoreDocs.length);
+assertThat(context.queryResult().topDocs().topDocs.totalHits.value, equalTo((long) numDocs));
+assertThat(context.queryResult().topDocs().topDocs, instanceOf(CollapseTopFieldDocs.class));
+topDocs = (CollapseTopFieldDocs) context.queryResult().topDocs().topDocs;
+assertThat(topDocs.collapseValues.length, equalTo(2));
+assertThat(topDocs.collapseValues[0], equalTo(0L)); // user == 0
+assertThat(topDocs.collapseValues[1], equalTo(1L)); // user == 1
+reader.close();
+dir.close();
+}
public void testCancellationDuringPreprocess() throws IOException {
try (Directory dir = newDirectory(); RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
@@ -982,7 +1111,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
-true
+true,
+null
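+// the added trailing argument is presumably the executor for the concurrent
+// segment search introduced by this commit; null keeps these test searchers
+// single-threaded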
);
}
@@ -992,7 +1122,8 @@ public class QueryPhaseTests extends IndexShardTestCase {
IndexSearcher.getDefaultSimilarity(),
IndexSearcher.getDefaultQueryCache(),
IndexSearcher.getDefaultQueryCachingPolicy(),
-true
+true,
+null
) {
@Override
@@ -1003,6 +1134,32 @@ public class QueryPhaseTests extends IndexShardTestCase {
};
}
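+// Test-only manager: newCollector() always hands back the same shared
+// TotalHitCountCollector so the test can read the final hit count afterwards.
+// This assumes the test searcher collects on a single thread; sharing one
+// collector instance would not be safe under concurrent collection.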
+private static class TestTotalHitCountCollectorManager extends TotalHitCountCollectorManager {
+private final TotalHitCountCollector collector;
+static TestTotalHitCountCollectorManager create() {
+return create(null);
+}
+static TestTotalHitCountCollectorManager create(final Sort sort) {
+return new TestTotalHitCountCollectorManager(new TotalHitCountCollector(), sort);
+}
+private TestTotalHitCountCollectorManager(final TotalHitCountCollector collector, final Sort sort) {
+super(sort);
+this.collector = collector;
+}
+@Override
+public TotalHitCountCollector newCollector() throws IOException {
+return collector;
+}
+public int getTotalHits() {
+return collector.getTotalHits();
+}
+}
private static class AssertingEarlyTerminationFilterCollector extends FilterCollector {
private final int size;
File diff suppressed because it is too large
View File
@@ -334,7 +334,8 @@ public abstract class AggregatorTestCase extends OpenSearchTestCase {
indexSearcher.getSimilarity(),
queryCache,
queryCachingPolicy,
-false
+false,
+null
);
SearchContext searchContext = mock(SearchContext.class);
View File
@@ -32,6 +32,7 @@
package org.opensearch.test;
import org.apache.lucene.search.Collector;
+import org.apache.lucene.search.CollectorManager;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.Query;
import org.opensearch.action.OriginalIndices;
@@ -70,6 +71,7 @@ import org.opensearch.search.internal.ShardSearchContextId;
import org.opensearch.search.internal.ShardSearchRequest;
import org.opensearch.search.profile.Profilers;
import org.opensearch.search.query.QuerySearchResult;
+import org.opensearch.search.query.ReduceableSearchResult;
import org.opensearch.search.rescore.RescoreContext;
import org.opensearch.search.sort.SortAndFormats;
import org.opensearch.search.suggest.SuggestionSearchContext;
@@ -90,7 +92,7 @@ public class TestSearchContext extends SearchContext {
final BigArrays bigArrays;
final IndexService indexService;
final BitsetFilterCache fixedBitSetFilterCache;
-final Map<Class<?>, Collector> queryCollectors = new HashMap<>();
+final Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers = new HashMap<>();
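+// Collector is replaced by CollectorManager so that per-segment collectors can
+// be created and their results reduced (ReduceableSearchResult) when segments
+// are searched concurrently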
final IndexShard indexShard;
final QuerySearchResult queryResult = new QuerySearchResult();
final QueryShardContext queryShardContext;
@@ -110,7 +112,9 @@ public class TestSearchContext extends SearchContext {
private SearchContextAggregations aggregations;
private ScrollContext scrollContext;
private FieldDoc searchAfter;
-private final long originNanoTime = System.nanoTime();
+private Profilers profilers;
+private CollapseContext collapse;
private final Map<String, SearchExtBuilder> searchExtBuilders = new HashMap<>();
public TestSearchContext(BigArrays bigArrays, IndexService indexService) {
@@ -405,12 +409,13 @@ public class TestSearchContext extends SearchContext {
@Override
public SearchContext collapse(CollapseContext collapse) {
-return null;
+this.collapse = collapse;
+return this;
}
@Override
public CollapseContext collapse() {
-return null;
+return collapse;
}
@Override
@@ -596,12 +601,12 @@ public class TestSearchContext extends SearchContext {
@Override
public Profilers getProfilers() {
-return null; // no profiling
+return profilers;
}
@Override
-public Map<Class<?>, Collector> queryCollectors() {
-return queryCollectors;
+public Map<Class<?>, CollectorManager<? extends Collector, ReduceableSearchResult>> queryCollectorManagers() {
+return queryCollectorManagers;
}
@Override
@@ -633,4 +638,21 @@ public class TestSearchContext extends SearchContext {
public ReaderContext readerContext() {
throw new UnsupportedOperationException();
}
+/**
+ * Clean the query results by consuming all of them
+ */
+public TestSearchContext withCleanQueryResult() {
+queryResult.consumeAll();
+profilers = null;
+return this;
+}
+/**
+ * Add profilers to the query
+ */
+public TestSearchContext withProfilers() {
+this.profilers = new Profilers(searcher);
+return this;
+}
}