Check for query cancellation during rewrite (#53166) (#53203)

With ExitableDirectoryReader in place, check for query cancellation
during QueryPhase#preProcess where the query rewriting takes place.

Follows: #52822

(cherry picked from commit 0d38626d8e6e9e2620a7a446b617a2ac42852461)
This commit is contained in:
Marios Trivyzas 2020-03-06 11:04:01 +01:00 committed by GitHub
parent c204137451
commit 7ddbda4c20
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 1 deletions

View File

@ -110,7 +110,24 @@ public class QueryPhase implements SearchPhase {
@Override
public void preProcess(SearchContext context) {
final Runnable cancellation;
if (context.lowLevelCancellation()) {
SearchShardTask task = context.getTask();
cancellation = context.searcher().addQueryCancellation(() -> {
if (task.isCancelled()) {
throw new TaskCancelledException("cancelled");
}
});
} else {
cancellation = null;
}
try {
context.preProcess(true);
} finally {
if (cancellation != null) {
context.searcher().removeQueryCancellation(cancellation);
}
}
}
@Override

View File

@ -53,6 +53,8 @@ import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.MultiTermQuery;
import org.apache.lucene.search.PrefixQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.Sort;
@ -76,6 +78,7 @@ import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperService;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.index.query.ParsedQuery;
import org.elasticsearch.index.query.QueryShardContext;
import org.elasticsearch.index.search.ESToParentBlockJoinQuery;
import org.elasticsearch.index.shard.IndexShard;
import org.elasticsearch.index.shard.IndexShardTestCase;
@ -84,6 +87,7 @@ import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.ScrollContext;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.sort.SortAndFormats;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.test.TestSearchContext;
import java.io.IOException;
@ -834,7 +838,59 @@ public class QueryPhaseTests extends IndexShardTestCase {
reader.close();
dir.close();
}
public void testCancellationDuringPreprocess() throws IOException {
try (Directory dir = newDirectory();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, newIndexWriterConfig())) {
for (int i = 0; i < 10; i++) {
Document doc = new Document();
StringBuilder sb = new StringBuilder();
for (int j = 0; j < i; j++) {
sb.append('a');
}
doc.add(new StringField("foo", sb.toString(), Store.NO));
w.addDocument(doc);
}
w.flush();
w.close();
try (IndexReader reader = DirectoryReader.open(dir)) {
TestSearchContext context = new TestSearchContextWithRewriteAndCancellation(
null, indexShard, newContextSearcher(reader));
PrefixQuery prefixQuery = new PrefixQuery(new Term("foo", "a"));
prefixQuery.setRewriteMethod(MultiTermQuery.SCORING_BOOLEAN_REWRITE);
context.parsedQuery(new ParsedQuery(prefixQuery));
SearchShardTask task = mock(SearchShardTask.class);
when(task.isCancelled()).thenReturn(true);
context.setTask(task);
expectThrows(TaskCancelledException.class, () -> new QueryPhase().preProcess(context));
}
}
}
private static class TestSearchContextWithRewriteAndCancellation extends TestSearchContext {
private TestSearchContextWithRewriteAndCancellation(QueryShardContext queryShardContext,
IndexShard indexShard,
ContextIndexSearcher searcher) {
super(queryShardContext, indexShard, searcher);
}
@Override
public void preProcess(boolean rewrite) {
try {
searcher().rewrite(query());
} catch (IOException e) {
fail("IOException shouldn't be thrown");
}
}
@Override
public boolean lowLevelCancellation() {
return true;
}
}
private static ContextIndexSearcher newContextSearcher(IndexReader reader) throws IOException {