Part 1: Support for cancel_after_timeinterval parameter in search and msearch request (#986)

* Part 1: Support for cancel_after_timeinterval parameter in search and msearch request

This commit introduces the new request level parameter to configure the timeout interval after which
a search request will be cancelled. For msearch request the parameter is supported both at parent
request and at sub child search requests. If it is provided at parent level and child search request
doesn't have it then the parent level value is set at such child request. The parent level msearch
is not used to cancel the parent request as it may be tricky to come up with correct value in cases
when child search request can have different runtimes

TEST: Added test for ser/de with new parameter

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Part 2: Support for cancel_after_timeinterval parameter in search and msearch request

This commit adds the handling of the new request level parameter and schedule cancellation task. It
also adds a cluster setting to set a global cancellation timeout for search request which will be
used in absence of request level timeout.

TEST: Added new tests in SearchCancellationIT
Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Address Review feedback for Part 1

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Address review feedback for Part 2

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Update CancellableTask to remove the cancelOnTimeout boolean flag

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Replace search.cancellation.timeout cluster setting with search.enforce_server.timeout.cancellation to control if cluster level cancel_after_time_interval should take precedence over request level cancel_after_time_interval value

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

* Removing the search.enforce_server.timeout.cancellation cluster setting and just keeping search.cancel_after_time_interval setting with request level parameter taking the precedence.

Signed-off-by: Sorabh Hamirwasia <sohami.apache@gmail.com>

Co-authored-by: Sorabh Hamirwasia <hsorabh@amazon.com>
This commit is contained in:
Sorabh 2021-08-12 08:01:28 -07:00 committed by GitHub
parent 5cd085c713
commit 9b6e621452
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 590 additions and 17 deletions

View File

@ -34,6 +34,7 @@ package org.opensearch.search;
import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.LogManager;
import org.junit.After;
import org.opensearch.ExceptionsHelper; import org.opensearch.ExceptionsHelper;
import org.opensearch.action.ActionFuture; import org.opensearch.action.ActionFuture;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse; import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksResponse;
@ -59,18 +60,24 @@ import org.opensearch.search.lookup.LeafFieldsLookup;
import org.opensearch.tasks.TaskCancelledException; import org.opensearch.tasks.TaskCancelledException;
import org.opensearch.tasks.TaskInfo; import org.opensearch.tasks.TaskInfo;
import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.transport.TransportException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function; import java.util.function.Function;
import static org.opensearch.action.search.TransportSearchAction.SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY;
import static org.opensearch.index.query.QueryBuilders.scriptQuery; import static org.opensearch.index.query.QueryBuilders.scriptQuery;
import static org.opensearch.search.SearchCancellationIT.ScriptedBlockPlugin.SCRIPT_NAME; import static org.opensearch.search.SearchCancellationIT.ScriptedBlockPlugin.SCRIPT_NAME;
import static org.opensearch.search.SearchService.NO_TIMEOUT;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertFailures; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertFailures;
import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertNoFailures;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
@ -96,6 +103,11 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
.build(); .build();
} }
@After
public void cleanup() {
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder().putNull("*")).get();
}
private void indexTestData() { private void indexTestData() {
for (int i = 0; i < 5; i++) { for (int i = 0; i < 5; i++) {
// Make sure we have a few segments // Make sure we have a few segments
@ -153,15 +165,51 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
SearchResponse response = searchResponse.actionGet(); SearchResponse response = searchResponse.actionGet();
logger.info("Search response {}", response); logger.info("Search response {}", response);
assertNotEquals("At least one shard should have failed", 0, response.getFailedShards()); assertNotEquals("At least one shard should have failed", 0, response.getFailedShards());
verifyCancellationException(response.getShardFailures());
return response; return response;
} catch (SearchPhaseExecutionException ex) { } catch (SearchPhaseExecutionException ex) {
logger.info("All shards failed with", ex); logger.info("All shards failed with", ex);
verifyCancellationException(ex.shardFailures());
return null; return null;
} }
} }
public void testCancellationDuringQueryPhase() throws Exception { private void ensureMSearchWasCancelled(ActionFuture<MultiSearchResponse> mSearchResponse,
Set<Integer> expectedFailedChildRequests) {
MultiSearchResponse response = mSearchResponse.actionGet();
Set<Integer> actualFailedChildRequests = new HashSet<>();
for (int i = 0; i < response.getResponses().length; ++i) {
SearchResponse sResponse = response.getResponses()[i].getResponse();
// check if response is null means all the shard failed for this search request
if (sResponse == null) {
Exception ex = response.getResponses()[i].getFailure();
assertTrue(ex instanceof SearchPhaseExecutionException);
verifyCancellationException(((SearchPhaseExecutionException)ex).shardFailures());
actualFailedChildRequests.add(i);
} else if (sResponse.getShardFailures().length > 0) {
verifyCancellationException(sResponse.getShardFailures());
actualFailedChildRequests.add(i);
}
}
assertEquals("Actual child request with cancellation failure is different that expected", expectedFailedChildRequests,
actualFailedChildRequests);
}
private void verifyCancellationException(ShardSearchFailure[] failures) {
for (ShardSearchFailure searchFailure : failures) {
// failure may happen while executing the search or while sending shard request for next phase.
// Below assertion is handling both the cases
final Throwable topFailureCause = searchFailure.getCause();
assertTrue(searchFailure.toString(), topFailureCause instanceof TransportException ||
topFailureCause instanceof TaskCancelledException);
if (topFailureCause instanceof TransportException) {
assertTrue(topFailureCause.getCause() instanceof TaskCancelledException);
}
}
}
public void testCancellationDuringQueryPhase() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData(); indexTestData();
@ -178,8 +226,49 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
ensureSearchWasCancelled(searchResponse); ensureSearchWasCancelled(searchResponse);
} }
public void testCancellationDuringFetchPhase() throws Exception { public void testCancellationDuringQueryPhaseUsingRequestParameter() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setCancelAfterTimeInterval(cancellationTimeout)
.setAllowPartialSearchResults(randomBoolean())
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME, Collections.emptyMap())))
.execute();
awaitForBlock(plugins);
// sleep for cancellation timeout to ensure scheduled cancellation task is actually executed
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
ensureSearchWasCancelled(searchResponse);
}
public void testCancellationDuringQueryPhaseUsingClusterSetting() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder()
.put(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, cancellationTimeout)
.build()).get();
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setAllowPartialSearchResults(randomBoolean())
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME, Collections.emptyMap())))
.execute();
awaitForBlock(plugins);
// sleep for cluster cancellation timeout to ensure scheduled cancellation task is actually executed
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
ensureSearchWasCancelled(searchResponse);
}
public void testCancellationDuringFetchPhase() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData(); indexTestData();
@ -196,8 +285,24 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
ensureSearchWasCancelled(searchResponse); ensureSearchWasCancelled(searchResponse);
} }
public void testCancellationOfScrollSearches() throws Exception { public void testCancellationDuringFetchPhaseUsingRequestParameter() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setCancelAfterTimeInterval(cancellationTimeout)
.addScriptField("test_field",
new Script(ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())
).execute();
awaitForBlock(plugins);
// sleep for request cancellation timeout to ensure scheduled cancellation task is actually executed
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
ensureSearchWasCancelled(searchResponse);
}
public void testCancellationOfScrollSearches() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData(); indexTestData();
@ -221,6 +326,29 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
} }
} }
public void testCancellationOfFirstScrollSearchRequestUsingRequestParameter() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setScroll(TimeValue.timeValueSeconds(10))
.setCancelAfterTimeInterval(cancellationTimeout)
.setSize(5)
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", SCRIPT_NAME, Collections.emptyMap())))
.execute();
awaitForBlock(plugins);
Thread.sleep(cancellationTimeout.getMillis());
disableBlocks(plugins);
SearchResponse response = ensureSearchWasCancelled(searchResponse);
if (response != null) {
// The response might not have failed on all shards - we need to clean scroll
logger.info("Cleaning scroll with id {}", response.getScrollId());
client().prepareClearScroll().addScrollId(response.getScrollId()).get();
}
}
public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exception { public void testCancellationOfScrollSearchesOnFollowupRequests() throws Exception {
@ -266,6 +394,93 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
client().prepareClearScroll().addScrollId(scrollId).get(); client().prepareClearScroll().addScrollId(scrollId).get();
} }
public void testNoCancellationOfScrollSearchOnFollowUpRequest() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
// Disable block so the first request would pass
disableBlocks(plugins);
TimeValue keepAlive = TimeValue.timeValueSeconds(5);
TimeValue cancellationTimeout = TimeValue.timeValueSeconds(2);
SearchResponse searchResponse = client().prepareSearch("test")
.setScroll(keepAlive)
.setCancelAfterTimeInterval(cancellationTimeout)
.setSize(2)
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME, Collections.emptyMap())))
.get();
assertNotNull(searchResponse.getScrollId());
// since the previous scroll response is received before cancellation timeout, the scheduled task will be cancelled. It will not
// be used for the subsequent scroll request, as request is of SearchScrollRequest type instead of SearchRequest type
// Enable block so the second request would block
for (ScriptedBlockPlugin plugin : plugins) {
plugin.reset();
plugin.enableBlock();
}
String scrollId = searchResponse.getScrollId();
ActionFuture<SearchResponse> scrollResponse = client().prepareSearchScroll(searchResponse.getScrollId())
.setScroll(keepAlive).execute();
awaitForBlock(plugins);
// sleep for cancellation timeout to ensure there is no scheduled task for cancellation
Thread.sleep(cancellationTimeout.getMillis());
disableBlocks(plugins);
// wait for response and ensure there is no failure
SearchResponse response = scrollResponse.get();
assertEquals(0, response.getFailedShards());
scrollId = response.getScrollId();
client().prepareClearScroll().addScrollId(scrollId).get();
}
public void testDisableCancellationAtRequestLevel() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder()
.put(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, cancellationTimeout)
.build()).get();
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setAllowPartialSearchResults(randomBoolean())
.setCancelAfterTimeInterval(NO_TIMEOUT)
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME, Collections.emptyMap())))
.execute();
awaitForBlock(plugins);
// sleep for cancellation timeout to ensure there is no scheduled task for cancellation
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
// ensure search was successful since cancellation was disabled at request level
assertEquals(0, searchResponse.get().getFailedShards());
}
public void testDisableCancellationAtClusterLevel() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder()
.put(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, NO_TIMEOUT)
.build()).get();
ActionFuture<SearchResponse> searchResponse = client().prepareSearch("test")
.setAllowPartialSearchResults(randomBoolean())
.setQuery(
scriptQuery(new Script(
ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME, Collections.emptyMap())))
.execute();
awaitForBlock(plugins);
// sleep for cancellation timeout to ensure there is no scheduled task for cancellation
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
// ensure search was successful since cancellation was disabled at request level
assertEquals(0, searchResponse.get().getFailedShards());
}
public void testCancelMultiSearch() throws Exception { public void testCancelMultiSearch() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory(); List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData(); indexTestData();
@ -287,6 +502,70 @@ public class SearchCancellationIT extends OpenSearchIntegTestCase {
} }
} }
public void testMSearchChildRequestCancellationWithClusterLevelTimeout() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue cancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder()
.put(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, cancellationTimeout)
.build()).get();
ActionFuture<MultiSearchResponse> mSearchResponse = client().prepareMultiSearch()
.add(client().prepareSearch("test").setAllowPartialSearchResults(randomBoolean())
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME,
Collections.emptyMap()))))
.add(client().prepareSearch("test").setAllowPartialSearchResults(randomBoolean()).setRequestCache(false)
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME,
Collections.emptyMap()))))
.execute();
awaitForBlock(plugins);
// sleep for cluster cancellation timeout to ensure scheduled cancellation task is actually executed
Thread.sleep(cancellationTimeout.getMillis());
// unblock the search thread
disableBlocks(plugins);
// both child requests are expected to fail
final Set<Integer> expectedFailedRequests = new HashSet<>();
expectedFailedRequests.add(0);
expectedFailedRequests.add(1);
ensureMSearchWasCancelled(mSearchResponse, expectedFailedRequests);
}
/**
* Verifies cancellation of sub search request with mix of request level and cluster level timeout parameter
* @throws Exception in case of unexpected errors
*/
public void testMSearchChildReqCancellationWithHybridTimeout() throws Exception {
List<ScriptedBlockPlugin> plugins = initBlockFactory();
indexTestData();
TimeValue reqCancellationTimeout = new TimeValue(2, TimeUnit.SECONDS);
TimeValue clusterCancellationTimeout = new TimeValue(3, TimeUnit.SECONDS);
client().admin().cluster().prepareUpdateSettings().setPersistentSettings(Settings.builder()
.put(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, clusterCancellationTimeout)
.build()).get();
ActionFuture<MultiSearchResponse> mSearchResponse = client().prepareMultiSearch()
.add(client().prepareSearch("test").setAllowPartialSearchResults(randomBoolean())
.setCancelAfterTimeInterval(reqCancellationTimeout)
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME,
Collections.emptyMap()))))
.add(client().prepareSearch("test").setAllowPartialSearchResults(randomBoolean())
.setCancelAfterTimeInterval(NO_TIMEOUT)
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME,
Collections.emptyMap()))))
.add(client().prepareSearch("test").setAllowPartialSearchResults(randomBoolean()).setRequestCache(false)
.setQuery(scriptQuery(new Script(ScriptType.INLINE, "mockscript", ScriptedBlockPlugin.SCRIPT_NAME,
Collections.emptyMap()))))
.execute();
awaitForBlock(plugins);
// sleep for cluster cancellation timeout to ensure scheduled cancellation task is actually executed
Thread.sleep(Math.max(reqCancellationTimeout.getMillis(), clusterCancellationTimeout.getMillis()));
// unblock the search thread
disableBlocks(plugins);
// only first and last child request are expected to fail
final Set<Integer> expectedFailedRequests = new HashSet<>();
expectedFailedRequests.add(0);
expectedFailedRequests.add(2);
ensureMSearchWasCancelled(mSearchResponse, expectedFailedRequests);
}
public static class ScriptedBlockPlugin extends MockScriptPlugin { public static class ScriptedBlockPlugin extends MockScriptPlugin {
static final String SCRIPT_NAME = "search_block"; static final String SCRIPT_NAME = "search_block";

View File

@ -66,6 +66,7 @@ import static org.opensearch.action.ValidateActions.addValidationError;
import static org.opensearch.common.xcontent.support.XContentMapValues.nodeBooleanValue; import static org.opensearch.common.xcontent.support.XContentMapValues.nodeBooleanValue;
import static org.opensearch.common.xcontent.support.XContentMapValues.nodeStringArrayValue; import static org.opensearch.common.xcontent.support.XContentMapValues.nodeStringArrayValue;
import static org.opensearch.common.xcontent.support.XContentMapValues.nodeStringValue; import static org.opensearch.common.xcontent.support.XContentMapValues.nodeStringValue;
import static org.opensearch.common.xcontent.support.XContentMapValues.nodeTimeValue;
/** /**
* A multi search API request. * A multi search API request.
@ -272,6 +273,9 @@ public class MultiSearchRequest extends ActionRequest implements CompositeIndice
allowNoIndices = value; allowNoIndices = value;
} else if ("ignore_throttled".equals(entry.getKey()) || "ignoreThrottled".equals(entry.getKey())) { } else if ("ignore_throttled".equals(entry.getKey()) || "ignoreThrottled".equals(entry.getKey())) {
ignoreThrottled = value; ignoreThrottled = value;
} else if ("cancel_after_time_interval".equals(entry.getKey()) ||
"cancelAfterTimeInterval".equals(entry.getKey())) {
searchRequest.setCancelAfterTimeInterval(nodeTimeValue(value, null));
} else { } else {
throw new IllegalArgumentException("key [" + entry.getKey() + "] is not supported in the metadata section"); throw new IllegalArgumentException("key [" + entry.getKey() + "] is not supported in the metadata section");
} }
@ -362,6 +366,9 @@ public class MultiSearchRequest extends ActionRequest implements CompositeIndice
if (request.allowPartialSearchResults() != null) { if (request.allowPartialSearchResults() != null) {
xContentBuilder.field("allow_partial_search_results", request.allowPartialSearchResults()); xContentBuilder.field("allow_partial_search_results", request.allowPartialSearchResults());
} }
if (request.getCancelAfterTimeInterval() != null) {
xContentBuilder.field("cancel_after_time_interval", request.getCancelAfterTimeInterval().getStringRep());
}
xContentBuilder.endObject(); xContentBuilder.endObject();
} }

View File

@ -33,6 +33,7 @@
package org.opensearch.action.search; package org.opensearch.action.search;
import org.opensearch.LegacyESVersion; import org.opensearch.LegacyESVersion;
import org.opensearch.Version;
import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.action.IndicesRequest; import org.opensearch.action.IndicesRequest;
@ -114,6 +115,8 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
private IndicesOptions indicesOptions = DEFAULT_INDICES_OPTIONS; private IndicesOptions indicesOptions = DEFAULT_INDICES_OPTIONS;
private TimeValue cancelAfterTimeInterval;
public SearchRequest() { public SearchRequest() {
this.localClusterAlias = null; this.localClusterAlias = null;
this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS; this.absoluteStartMillis = DEFAULT_ABSOLUTE_START_MILLIS;
@ -191,6 +194,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
this.localClusterAlias = localClusterAlias; this.localClusterAlias = localClusterAlias;
this.absoluteStartMillis = absoluteStartMillis; this.absoluteStartMillis = absoluteStartMillis;
this.finalReduce = finalReduce; this.finalReduce = finalReduce;
this.cancelAfterTimeInterval = searchRequest.cancelAfterTimeInterval;
} }
/** /**
@ -237,6 +241,10 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
if (in.getVersion().onOrAfter(LegacyESVersion.V_7_0_0)) { if (in.getVersion().onOrAfter(LegacyESVersion.V_7_0_0)) {
ccsMinimizeRoundtrips = in.readBoolean(); ccsMinimizeRoundtrips = in.readBoolean();
} }
if (in.getVersion().onOrAfter(Version.V_1_1_0)) {
cancelAfterTimeInterval = in.readOptionalTimeValue();
}
} }
@Override @Override
@ -271,6 +279,10 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
if (out.getVersion().onOrAfter(LegacyESVersion.V_7_0_0)) { if (out.getVersion().onOrAfter(LegacyESVersion.V_7_0_0)) {
out.writeBoolean(ccsMinimizeRoundtrips); out.writeBoolean(ccsMinimizeRoundtrips);
} }
if (out.getVersion().onOrAfter(Version.V_1_1_0)) {
out.writeOptionalTimeValue(cancelAfterTimeInterval);
}
} }
@Override @Override
@ -669,9 +681,17 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO : source.trackTotalHitsUpTo(); SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO : source.trackTotalHitsUpTo();
} }
public void setCancelAfterTimeInterval(TimeValue cancelAfterTimeInterval) {
this.cancelAfterTimeInterval = cancelAfterTimeInterval;
}
public TimeValue getCancelAfterTimeInterval() {
return cancelAfterTimeInterval;
}
@Override @Override
public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) { public SearchTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers); return new SearchTask(id, type, action, this::buildDescription, parentTaskId, headers, cancelAfterTimeInterval);
} }
public final String buildDescription() { public final String buildDescription() {
@ -718,14 +738,15 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
Objects.equals(allowPartialSearchResults, that.allowPartialSearchResults) && Objects.equals(allowPartialSearchResults, that.allowPartialSearchResults) &&
Objects.equals(localClusterAlias, that.localClusterAlias) && Objects.equals(localClusterAlias, that.localClusterAlias) &&
absoluteStartMillis == that.absoluteStartMillis && absoluteStartMillis == that.absoluteStartMillis &&
ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips; ccsMinimizeRoundtrips == that.ccsMinimizeRoundtrips &&
Objects.equals(cancelAfterTimeInterval, that.cancelAfterTimeInterval);
} }
@Override @Override
public int hashCode() { public int hashCode() {
return Objects.hash(searchType, Arrays.hashCode(indices), routing, preference, source, requestCache, return Objects.hash(searchType, Arrays.hashCode(indices), routing, preference, source, requestCache,
scroll, Arrays.hashCode(types), indicesOptions, batchedReduceSize, maxConcurrentShardRequests, preFilterShardSize, scroll, Arrays.hashCode(types), indicesOptions, batchedReduceSize, maxConcurrentShardRequests, preFilterShardSize,
allowPartialSearchResults, localClusterAlias, absoluteStartMillis, ccsMinimizeRoundtrips); allowPartialSearchResults, localClusterAlias, absoluteStartMillis, ccsMinimizeRoundtrips, cancelAfterTimeInterval);
} }
@Override @Override
@ -746,6 +767,7 @@ public class SearchRequest extends ActionRequest implements IndicesRequest.Repla
", localClusterAlias=" + localClusterAlias + ", localClusterAlias=" + localClusterAlias +
", getOrCreateAbsoluteStartMillis=" + absoluteStartMillis + ", getOrCreateAbsoluteStartMillis=" + absoluteStartMillis +
", ccsMinimizeRoundtrips=" + ccsMinimizeRoundtrips + ", ccsMinimizeRoundtrips=" + ccsMinimizeRoundtrips +
", source=" + source + '}'; ", source=" + source +
", cancelAfterTimeInterval=" + cancelAfterTimeInterval + "}";
} }
} }

View File

@ -626,4 +626,12 @@ public class SearchRequestBuilder extends ActionRequestBuilder<SearchRequest, Se
this.request.setPreFilterShardSize(preFilterShardSize); this.request.setPreFilterShardSize(preFilterShardSize);
return this; return this;
} }
/**
* Request level time interval to control how long search is allowed to execute after which it is cancelled.
*/
public SearchRequestBuilder setCancelAfterTimeInterval(TimeValue cancelAfterTimeInterval) {
this.request.setCancelAfterTimeInterval(cancelAfterTimeInterval);
return this;
}
} }

View File

@ -32,12 +32,15 @@
package org.opensearch.action.search; package org.opensearch.action.search;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.tasks.CancellableTask; import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskId;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier; import java.util.function.Supplier;
import static org.opensearch.search.SearchService.NO_TIMEOUT;
/** /**
* Task storing information about a currently running {@link SearchRequest}. * Task storing information about a currently running {@link SearchRequest}.
*/ */
@ -46,9 +49,14 @@ public class SearchTask extends CancellableTask {
private final Supplier<String> descriptionSupplier; private final Supplier<String> descriptionSupplier;
private SearchProgressListener progressListener = SearchProgressListener.NOOP; private SearchProgressListener progressListener = SearchProgressListener.NOOP;
public SearchTask(long id, String type, String action, Supplier<String> descriptionSupplier, public SearchTask(long id, String type, String action, Supplier<String> descriptionSupplier, TaskId parentTaskId,
TaskId parentTaskId, Map<String, String> headers) { Map<String, String> headers) {
super(id, type, action, null, parentTaskId, headers); this(id, type, action, descriptionSupplier, parentTaskId, headers, NO_TIMEOUT);
}
public SearchTask(long id, String type, String action, Supplier<String> descriptionSupplier, TaskId parentTaskId,
Map<String, String> headers, TimeValue cancelAfterTimeInterval) {
super(id, type, action, null, parentTaskId, headers, cancelAfterTimeInterval);
this.descriptionSupplier = descriptionSupplier; this.descriptionSupplier = descriptionSupplier;
} }

View File

@ -41,6 +41,7 @@ import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction; import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.IndicesOptions;
import org.opensearch.action.support.TimeoutTaskCancellationUtility;
import org.opensearch.client.Client; import org.opensearch.client.Client;
import org.opensearch.client.OriginSettingClient; import org.opensearch.client.OriginSettingClient;
import org.opensearch.client.node.NodeClient; import org.opensearch.client.node.NodeClient;
@ -81,6 +82,7 @@ import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.internal.SearchContext; import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.profile.ProfileShardResult; import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.SearchProfileShardResults; import org.opensearch.search.profile.SearchProfileShardResults;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.Task; import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskId; import org.opensearch.tasks.TaskId;
import org.opensearch.threadpool.ThreadPool; import org.opensearch.threadpool.ThreadPool;
@ -121,6 +123,13 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting( public static final Setting<Long> SHARD_COUNT_LIMIT_SETTING = Setting.longSetting(
"action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope); "action.search.shard_count.limit", Long.MAX_VALUE, 1L, Property.Dynamic, Property.NodeScope);
// cluster level setting for timeout based search cancellation. If search request level parameter is present then that will take
// precedence over the cluster setting value
public static final String SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY = "search.cancel_after_time_interval";
public static final Setting<TimeValue> SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING =
Setting.timeSetting(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING_KEY, SearchService.NO_TIMEOUT, Setting.Property.Dynamic,
Setting.Property.NodeScope);
private final NodeClient client; private final NodeClient client;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final ClusterService clusterService; private final ClusterService clusterService;
@ -239,6 +248,14 @@ public class TransportSearchAction extends HandledTransportAction<SearchRequest,
@Override @Override
protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<SearchResponse> listener) { protected void doExecute(Task task, SearchRequest searchRequest, ActionListener<SearchResponse> listener) {
// only if task is of type CancellableTask and support cancellation on timeout, treat this request eligible for timeout based
// cancellation. There may be other top level requests like AsyncSearch which is using SearchRequest internally and has it's own
// cancellation mechanism. For such cases, the SearchRequest when created can override the createTask and set the
// cancelAfterTimeInterval to NO_TIMEOUT and bypass this mechanism
if (task instanceof CancellableTask) {
listener = TimeoutTaskCancellationUtility.wrapWithCancellationListener(client, (CancellableTask) task,
clusterService.getClusterSettings(), listener);
}
executeRequest(task, searchRequest, this::searchAsyncAction, listener); executeRequest(task, searchRequest, this::searchAsyncAction, listener);
} }

View File

@ -0,0 +1,135 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.action.support;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.ActionListener;
import org.opensearch.action.admin.cluster.node.tasks.cancel.CancelTasksRequest;
import org.opensearch.client.OriginSettingClient;
import org.opensearch.client.node.NodeClient;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.search.SearchService;
import org.opensearch.tasks.CancellableTask;
import org.opensearch.tasks.TaskId;
import org.opensearch.threadpool.Scheduler;
import org.opensearch.threadpool.ThreadPool;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import static org.opensearch.action.admin.cluster.node.tasks.get.GetTaskAction.TASKS_ORIGIN;
import static org.opensearch.action.search.TransportSearchAction.SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING;
public class TimeoutTaskCancellationUtility {
private static final Logger logger = LogManager.getLogger(TimeoutTaskCancellationUtility.class);
/**
* Wraps a listener with a timeout listener {@link TimeoutRunnableListener} to schedule the task cancellation for provided tasks on
* generic thread pool
* @param client - {@link NodeClient}
* @param taskToCancel - task to schedule cancellation for
* @param clusterSettings - {@link ClusterSettings}
* @param listener - original listener associated with the task
* @return wrapped listener
*/
public static <Response> ActionListener<Response> wrapWithCancellationListener(NodeClient client, CancellableTask taskToCancel,
ClusterSettings clusterSettings, ActionListener<Response> listener) {
final TimeValue globalTimeout = clusterSettings.get(SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING);
final TimeValue timeoutInterval = (taskToCancel.getCancellationTimeout() == null) ? globalTimeout
: taskToCancel.getCancellationTimeout();
// Note: -1 (or no timeout) will help to turn off cancellation. The combinations will be request level set at -1 or request level
// set to null and cluster level set to -1.
ActionListener<Response> listenerToReturn = listener;
if (timeoutInterval.equals(SearchService.NO_TIMEOUT)) {
return listenerToReturn;
}
try {
final TimeoutRunnableListener<Response> wrappedListener = new TimeoutRunnableListener<>(timeoutInterval, listener, () -> {
final CancelTasksRequest cancelTasksRequest = new CancelTasksRequest();
cancelTasksRequest.setTaskId(new TaskId(client.getLocalNodeId(), taskToCancel.getId()));
cancelTasksRequest.setReason("Cancellation timeout of " + timeoutInterval + " is expired");
// force the origin to execute the cancellation as a system user
new OriginSettingClient(client, TASKS_ORIGIN).admin().cluster()
.cancelTasks(cancelTasksRequest, ActionListener.wrap(r -> logger.debug(
"Scheduled cancel task with timeout: {} for original task: {} is successfully completed", timeoutInterval,
cancelTasksRequest.getTaskId()),
e -> logger.error(new ParameterizedMessage("Scheduled cancel task with timeout: {} for original task: {} is failed",
timeoutInterval, cancelTasksRequest.getTaskId()), e))
);
});
wrappedListener.cancellable = client.threadPool().schedule(wrappedListener, timeoutInterval, ThreadPool.Names.GENERIC);
listenerToReturn = wrappedListener;
} catch (Exception ex) {
// if there is any exception in scheduling the cancellation task then continue without it
logger.warn("Failed to schedule the cancellation task for original task: {}, will continue without it", taskToCancel.getId());
}
return listenerToReturn;
}
/**
* Timeout listener which executes the provided runnable after timeout is expired and if a response/failure is not yet received.
* If either a response/failure is received before timeout then the scheduled task is cancelled and response/failure is sent back to
* the original listener.
*/
private static class TimeoutRunnableListener<Response> implements ActionListener<Response>, Runnable {
private static final Logger logger = LogManager.getLogger(TimeoutRunnableListener.class);
// Runnable to execute after timeout
private final TimeValue timeout;
private final ActionListener<Response> originalListener;
private final Runnable timeoutRunnable;
private final AtomicBoolean executeRunnable = new AtomicBoolean(true);
private volatile Scheduler.ScheduledCancellable cancellable;
private final long creationTime;
TimeoutRunnableListener(TimeValue timeout, ActionListener<Response> listener, Runnable runAfterTimeout) {
this.timeout = timeout;
this.originalListener = listener;
this.timeoutRunnable = runAfterTimeout;
this.creationTime = System.nanoTime();
}
@Override public void onResponse(Response response) {
checkAndCancel();
originalListener.onResponse(response);
}
@Override public void onFailure(Exception e) {
checkAndCancel();
originalListener.onFailure(e);
}
@Override public void run() {
try {
if (executeRunnable.compareAndSet(true, false)) {
timeoutRunnable.run();
} // else do nothing since either response/failure is already sent to client
} catch (Exception ex) {
// ignore the exception
logger.error(new ParameterizedMessage("Ignoring the failure to run the provided runnable after timeout of {} with " +
"exception", timeout), ex);
}
}
private void checkAndCancel() {
if (executeRunnable.compareAndSet(true, false)) {
logger.debug("Aborting the scheduled cancel task after {}",
TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - creationTime));
// timer has not yet expired so cancel it
cancellable.cancel();
}
}
}
}

View File

@ -345,6 +345,7 @@ public final class ClusterSettings extends AbstractScopedSettings {
SearchService.DEFAULT_ALLOW_PARTIAL_SEARCH_RESULTS, SearchService.DEFAULT_ALLOW_PARTIAL_SEARCH_RESULTS,
ElectMasterService.DISCOVERY_ZEN_MINIMUM_MASTER_NODES_SETTING, ElectMasterService.DISCOVERY_ZEN_MINIMUM_MASTER_NODES_SETTING,
TransportSearchAction.SHARD_COUNT_LIMIT_SETTING, TransportSearchAction.SHARD_COUNT_LIMIT_SETTING,
TransportSearchAction.SEARCH_CANCEL_AFTER_TIME_INTERVAL_SETTING,
RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE, RemoteClusterService.REMOTE_CLUSTER_SKIP_UNAVAILABLE,
RemoteClusterService.SEARCH_REMOTE_CLUSTER_SKIP_UNAVAILABLE, RemoteClusterService.SEARCH_REMOTE_CLUSTER_SKIP_UNAVAILABLE,
SniffConnectionStrategy.REMOTE_CONNECTIONS_PER_CLUSTER, SniffConnectionStrategy.REMOTE_CONNECTIONS_PER_CLUSTER,

View File

@ -44,6 +44,7 @@ import org.opensearch.common.collect.Tuple;
import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.XContent; import org.opensearch.common.xcontent.XContent;
import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType; import org.opensearch.common.xcontent.XContentType;
@ -158,6 +159,7 @@ public class RestMultiSearchAction extends BaseRestHandler {
multiRequest.add(searchRequest); multiRequest.add(searchRequest);
}); });
List<SearchRequest> requests = multiRequest.requests(); List<SearchRequest> requests = multiRequest.requests();
final TimeValue cancelAfterTimeInterval = restRequest.paramAsTime("cancel_after_time_interval", null);
for (SearchRequest request : requests) { for (SearchRequest request : requests) {
// preserve if it's set on the request // preserve if it's set on the request
if (preFilterShardSize != null && request.getPreFilterShardSize() == null) { if (preFilterShardSize != null && request.getPreFilterShardSize() == null) {
@ -166,6 +168,11 @@ public class RestMultiSearchAction extends BaseRestHandler {
if (maxConcurrentShardRequests != null) { if (maxConcurrentShardRequests != null) {
request.setMaxConcurrentShardRequests(maxConcurrentShardRequests); request.setMaxConcurrentShardRequests(maxConcurrentShardRequests);
} }
// if cancel_after_time_interval parameter is set at per search request level than that is used otherwise one set at
// multi search request level will be used
if (request.getCancelAfterTimeInterval() == null) {
request.setCancelAfterTimeInterval(cancelAfterTimeInterval);
}
} }
return multiRequest; return multiRequest;
} }

View File

@ -208,6 +208,8 @@ public class RestSearchAction extends BaseRestHandler {
searchRequest.setCcsMinimizeRoundtrips( searchRequest.setCcsMinimizeRoundtrips(
request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips())); request.paramAsBoolean("ccs_minimize_roundtrips", searchRequest.isCcsMinimizeRoundtrips()));
} }
searchRequest.setCancelAfterTimeInterval(request.paramAsTime("cancel_after_time_interval", null));
} }
/** /**

View File

@ -62,7 +62,7 @@ public class DfsPhase {
@Override @Override
public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException { public TermStatistics termStatistics(Term term, int docFreq, long totalTermFreq) throws IOException {
if (context.isCancelled()) { if (context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
} }
TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq); TermStatistics ts = super.termStatistics(term, docFreq, totalTermFreq);
if (ts != null) { if (ts != null) {
@ -74,7 +74,7 @@ public class DfsPhase {
@Override @Override
public CollectionStatistics collectionStatistics(String field) throws IOException { public CollectionStatistics collectionStatistics(String field) throws IOException {
if (context.isCancelled()) { if (context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
} }
CollectionStatistics cs = super.collectionStatistics(field); CollectionStatistics cs = super.collectionStatistics(field);
if (cs != null) { if (cs != null) {

View File

@ -108,7 +108,7 @@ public class FetchPhase {
} }
if (context.isCancelled()) { if (context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
} }
if (context.docIdsToLoadSize() == 0) { if (context.docIdsToLoadSize() == 0) {
@ -140,7 +140,7 @@ public class FetchPhase {
boolean hasSequentialDocs = hasSequentialDocs(docs); boolean hasSequentialDocs = hasSequentialDocs(docs);
for (int index = 0; index < context.docIdsToLoadSize(); index++) { for (int index = 0; index < context.docIdsToLoadSize(); index++) {
if (context.isCancelled()) { if (context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
} }
int docId = docs[index].docId; int docId = docs[index].docId;
try { try {
@ -181,7 +181,7 @@ public class FetchPhase {
} }
} }
if (context.isCancelled()) { if (context.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + context.getTask().getReasonCancelled());
} }
TotalHits totalHits = context.queryResult().getTotalHits(); TotalHits totalHits = context.queryResult().getTotalHits();

View File

@ -126,7 +126,7 @@ public class QueryPhase {
cancellation = context.searcher().addQueryCancellation(() -> { cancellation = context.searcher().addQueryCancellation(() -> {
SearchShardTask task = context.getTask(); SearchShardTask task = context.getTask();
if (task != null && task.isCancelled()) { if (task != null && task.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled());
} }
}); });
} else { } else {
@ -295,7 +295,7 @@ public class QueryPhase {
searcher.addQueryCancellation(() -> { searcher.addQueryCancellation(() -> {
SearchShardTask task = searchContext.getTask(); SearchShardTask task = searchContext.getTask();
if (task != null && task.isCancelled()) { if (task != null && task.isCancelled()) {
throw new TaskCancelledException("cancelled"); throw new TaskCancelledException("cancelled task with reason: " + task.getReasonCancelled());
} }
}); });
} }

View File

@ -33,10 +33,13 @@
package org.opensearch.tasks; package org.opensearch.tasks;
import org.opensearch.common.Nullable; import org.opensearch.common.Nullable;
import org.opensearch.common.unit.TimeValue;
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import static org.opensearch.search.SearchService.NO_TIMEOUT;
/** /**
* A task that can be canceled * A task that can be canceled
*/ */
@ -44,9 +47,16 @@ public abstract class CancellableTask extends Task {
private volatile String reason; private volatile String reason;
private final AtomicBoolean cancelled = new AtomicBoolean(false); private final AtomicBoolean cancelled = new AtomicBoolean(false);
private final TimeValue cancelAfterTimeInterval;
public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) { public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
this(id, type, action, description, parentTaskId, headers, NO_TIMEOUT);
}
public CancellableTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers,
TimeValue cancelAfterTimeInterval) {
super(id, type, action, description, parentTaskId, headers); super(id, type, action, description, parentTaskId, headers);
this.cancelAfterTimeInterval = cancelAfterTimeInterval;
} }
/** /**
@ -77,6 +87,10 @@ public abstract class CancellableTask extends Task {
return cancelled.get(); return cancelled.get();
} }
public TimeValue getCancellationTimeout() {
return cancelAfterTimeInterval;
}
/** /**
* The reason the task was cancelled or null if it hasn't been cancelled. * The reason the task was cancelled or null if it hasn't been cancelled.
*/ */

View File

@ -32,6 +32,7 @@
package org.opensearch.action.search; package org.opensearch.action.search;
import org.opensearch.Version;
import org.opensearch.action.support.IndicesOptions; import org.opensearch.action.support.IndicesOptions;
import org.opensearch.common.CheckedBiConsumer; import org.opensearch.common.CheckedBiConsumer;
import org.opensearch.common.CheckedRunnable; import org.opensearch.common.CheckedRunnable;
@ -39,7 +40,9 @@ import org.opensearch.common.ParseField;
import org.opensearch.common.Strings; import org.opensearch.common.Strings;
import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.logging.DeprecationLogger; import org.opensearch.common.logging.DeprecationLogger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentHelper;
@ -54,6 +57,7 @@ import org.opensearch.search.Scroll;
import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.StreamsUtils; import org.opensearch.test.StreamsUtils;
import org.opensearch.test.VersionUtils;
import org.opensearch.test.rest.FakeRestRequest; import org.opensearch.test.rest.FakeRestRequest;
import java.io.IOException; import java.io.IOException;
@ -62,6 +66,7 @@ import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.TimeUnit;
import static java.util.Collections.singletonList; import static java.util.Collections.singletonList;
import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest; import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest;
@ -136,6 +141,38 @@ public class MultiSearchRequestTests extends OpenSearchTestCase {
assertThat(request.requests().get(0).types().length, equalTo(0)); assertThat(request.requests().get(0).types().length, equalTo(0));
} }
public void testCancelAfterIntervalAtParentAndFewChildRequest() throws Exception {
final String requestContent = "{\"index\":\"test\", \"expand_wildcards\" : \"open,closed\", " +
"\"cancel_after_time_interval\" : \"10s\"}\r\n" +
"{\"query\" : {\"match_all\" :{}}}\r\n {\"search_type\" : \"dfs_query_then_fetch\"}\n" +
"{\"query\" : {\"match_all\" :{}}}\r\n";
FakeRestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry())
.withContent(new BytesArray(requestContent), XContentType.JSON)
.withParams(Collections.singletonMap("cancel_after_time_interval", "20s"))
.build();
MultiSearchRequest request = RestMultiSearchAction.parseRequest(restRequest, null, true);
assertThat(request.requests().size(), equalTo(2));
assertThat(request.requests().get(0).indices()[0], equalTo("test"));
// verifies that child search request parameter value is used for first search request
assertEquals(new TimeValue(10, TimeUnit.SECONDS), request.requests().get(0).getCancelAfterTimeInterval());
// verifies that parent msearch parameter value is used for second search request
assertEquals(request.requests().get(1).searchType(), SearchType.DFS_QUERY_THEN_FETCH);
assertEquals(new TimeValue(20, TimeUnit.SECONDS), request.requests().get(1).getCancelAfterTimeInterval());
}
public void testOnlyParentMSearchRequestWithCancelAfterTimeIntervalParameter() throws IOException {
final String requestContent = "{\"index\":\"test\", \"expand_wildcards\" : \"open,closed\"}}\r\n" +
"{\"query\" : {\"match_all\" :{}}}\r\n";
FakeRestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry())
.withContent(new BytesArray(requestContent), XContentType.JSON)
.withParams(Collections.singletonMap("cancel_after_time_interval", "20s"))
.build();
MultiSearchRequest request = RestMultiSearchAction.parseRequest(restRequest, null, true);
assertThat(request.requests().size(), equalTo(1));
assertThat(request.requests().get(0).indices()[0], equalTo("test"));
assertEquals(new TimeValue(20, TimeUnit.SECONDS), request.requests().get(0).getCancelAfterTimeInterval());
}
public void testDefaultIndicesOptions() throws IOException { public void testDefaultIndicesOptions() throws IOException {
final String requestContent = "{\"index\":\"test\", \"expand_wildcards\" : \"open,closed\"}}\r\n" + final String requestContent = "{\"index\":\"test\", \"expand_wildcards\" : \"open,closed\"}}\r\n" +
"{\"query\" : {\"match_all\" :{}}}\r\n"; "{\"query\" : {\"match_all\" :{}}}\r\n";
@ -316,6 +353,12 @@ public class MultiSearchRequestTests extends OpenSearchTestCase {
new ParseField(MatchAllQueryBuilder.NAME), (p, c) -> MatchAllQueryBuilder.fromXContent(p)))); new ParseField(MatchAllQueryBuilder.NAME), (p, c) -> MatchAllQueryBuilder.fromXContent(p))));
} }
@Override
protected NamedWriteableRegistry writableRegistry() {
return new NamedWriteableRegistry(singletonList(new NamedWriteableRegistry.Entry(QueryBuilder.class,
MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)));
}
public void testMultiLineSerialization() throws IOException { public void testMultiLineSerialization() throws IOException {
int iters = 16; int iters = 16;
for (int i = 0; i < iters; i++) { for (int i = 0; i < iters; i++) {
@ -338,6 +381,24 @@ public class MultiSearchRequestTests extends OpenSearchTestCase {
} }
} }
public void testSerDeWithCancelAfterTimeIntervalParameterAndRandomVersion() throws IOException {
final String requestContent = "{\"index\":\"test\", \"expand_wildcards\" : \"open,closed\", " +
"\"cancel_after_time_interval\" : \"10s\"}\r\n{\"query\" : {\"match_all\" :{}}}\r\n";
FakeRestRequest restRequest = new FakeRestRequest.Builder(xContentRegistry())
.withContent(new BytesArray(requestContent), XContentType.JSON)
.build();
Version version = VersionUtils.randomVersion(random());
MultiSearchRequest originalRequest = RestMultiSearchAction.parseRequest(restRequest, null, true);
MultiSearchRequest deserializedRequest = copyWriteable(originalRequest, writableRegistry(), MultiSearchRequest::new, version);
if (version.before(Version.V_1_1_0)) {
assertNull(deserializedRequest.requests().get(0).getCancelAfterTimeInterval());
} else {
assertEquals(originalRequest.requests().get(0).getCancelAfterTimeInterval(),
deserializedRequest.requests().get(0).getCancelAfterTimeInterval());
}
}
public void testWritingExpandWildcards() throws IOException { public void testWritingExpandWildcards() throws IOException {
assertExpandWildcardsValue(IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), true, true, true, randomBoolean(), assertExpandWildcardsValue(IndicesOptions.fromOptions(randomBoolean(), randomBoolean(), true, true, true, randomBoolean(),
randomBoolean(), randomBoolean(), randomBoolean()), "all"); randomBoolean(), randomBoolean(), randomBoolean()), "all");

View File

@ -115,6 +115,12 @@ public class SearchRequestTests extends AbstractSearchTestCase {
assertEquals(searchRequest.getAbsoluteStartMillis(), deserializedRequest.getAbsoluteStartMillis()); assertEquals(searchRequest.getAbsoluteStartMillis(), deserializedRequest.getAbsoluteStartMillis());
assertEquals(searchRequest.isFinalReduce(), deserializedRequest.isFinalReduce()); assertEquals(searchRequest.isFinalReduce(), deserializedRequest.isFinalReduce());
} }
if (version.onOrAfter(Version.V_1_1_0)) {
assertEquals(searchRequest.getCancelAfterTimeInterval(), deserializedRequest.getCancelAfterTimeInterval());
} else {
assertNull(deserializedRequest.getCancelAfterTimeInterval());
}
} }
public void testReadFromPre6_7_0() throws IOException { public void testReadFromPre6_7_0() throws IOException {
@ -261,6 +267,8 @@ public class SearchRequestTests extends AbstractSearchTestCase {
() -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)))); () -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH))));
mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder))); mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder)));
mutators.add(() -> mutation.setCcsMinimizeRoundtrips(searchRequest.isCcsMinimizeRoundtrips() == false)); mutators.add(() -> mutation.setCcsMinimizeRoundtrips(searchRequest.isCcsMinimizeRoundtrips() == false));
mutators.add(() -> mutation.setCancelAfterTimeInterval(searchRequest.getCancelAfterTimeInterval() != null
? null : TimeValue.parseTimeValue(randomTimeValue(), null, "cancel_after_time_interval")));
randomFrom(mutators).run(); randomFrom(mutators).run();
return mutation; return mutation;
} }

View File

@ -130,6 +130,10 @@ public class RandomSearchRequestGenerator {
if (randomBoolean()) { if (randomBoolean()) {
searchRequest.source(randomSearchSourceBuilder.get()); searchRequest.source(randomSearchSourceBuilder.get());
} }
if (randomBoolean()) {
searchRequest.setCancelAfterTimeInterval(
TimeValue.parseTimeValue(randomTimeValue(), null, "cancel_after_time_interval"));
}
return searchRequest; return searchRequest;
} }