Plugins: Replace Rest filters with RestHandler wrapper (#21905)

* Plugins: Replace Rest filters with RestHandler wrapper

RestFilters are a complex way of allowing plugins to add extra code
before rest actions are executed. This change removes rest filters, and
replaces with a wrapper which a single plugin may provide.
This commit is contained in:
Ryan Ernst 2016-12-02 14:54:51 -08:00 committed by GitHub
parent 0ecdef026d
commit 34eb23e98e
17 changed files with 112 additions and 452 deletions

View File

@ -19,12 +19,20 @@
package org.elasticsearch.action;
import java.io.IOException;
import java.util.function.Consumer;
/**
* A listener for action responses or failures.
*/
public interface ActionListener<Response> {
/** A consumer interface which allows throwing checked exceptions. */
@FunctionalInterface
interface CheckedConsumer<T> {
void accept(T t) throws Exception;
}
/**
* Handle action response. This response may constitute a failure or a
* success but it is up to the listener to make that decision.
@ -45,7 +53,7 @@ public interface ActionListener<Response> {
* @param <Response> the type of the response
* @return a listener that listens for responses and invokes the consumer when received
*/
static <Response> ActionListener<Response> wrap(Consumer<Response> onResponse, Consumer<Exception> onFailure) {
static <Response> ActionListener<Response> wrap(CheckedConsumer<Response> onResponse, Consumer<Exception> onFailure) {
return new ActionListener<Response>() {
@Override
public void onResponse(Response response) {

View File

@ -24,8 +24,10 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.admin.cluster.allocation.ClusterAllocationExplainAction;
import org.elasticsearch.action.admin.cluster.allocation.TransportClusterAllocationExplainAction;
import org.elasticsearch.action.admin.cluster.health.ClusterHealthAction;
@ -201,6 +203,7 @@ import org.elasticsearch.common.NamedRegistry;
import org.elasticsearch.common.inject.AbstractModule;
import org.elasticsearch.common.inject.multibindings.MapBinder;
import org.elasticsearch.common.inject.multibindings.Multibinder;
import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Settings;
@ -310,6 +313,7 @@ import org.elasticsearch.rest.action.search.RestMultiSearchAction;
import org.elasticsearch.rest.action.search.RestSearchAction;
import org.elasticsearch.rest.action.search.RestSearchScrollAction;
import org.elasticsearch.rest.action.search.RestSuggestAction;
import org.elasticsearch.threadpool.ThreadPool;
import static java.util.Collections.unmodifiableList;
import static java.util.Collections.unmodifiableMap;
@ -319,6 +323,8 @@ import static java.util.Collections.unmodifiableMap;
*/
public class ActionModule extends AbstractModule {
private static final Logger logger = ESLoggerFactory.getLogger(ActionModule.class);
private final boolean transportClient;
private final Settings settings;
private final List<ActionPlugin> actionPlugins;
@ -329,7 +335,7 @@ public class ActionModule extends AbstractModule {
private final RestController restController;
public ActionModule(boolean ingestEnabled, boolean transportClient, Settings settings, IndexNameExpressionResolver resolver,
ClusterSettings clusterSettings, List<ActionPlugin> actionPlugins) {
ClusterSettings clusterSettings, ThreadPool threadPool, List<ActionPlugin> actionPlugins) {
this.transportClient = transportClient;
this.settings = settings;
this.actionPlugins = actionPlugins;
@ -338,7 +344,18 @@ public class ActionModule extends AbstractModule {
autoCreateIndex = transportClient ? null : new AutoCreateIndex(settings, clusterSettings, resolver);
destructiveOperations = new DestructiveOperations(settings, clusterSettings);
Set<String> headers = actionPlugins.stream().flatMap(p -> p.getRestHeaders().stream()).collect(Collectors.toSet());
restController = new RestController(settings, headers);
UnaryOperator<RestHandler> restWrapper = null;
for (ActionPlugin plugin : actionPlugins) {
UnaryOperator<RestHandler> newRestWrapper = plugin.getRestHandlerWrapper(threadPool.getThreadContext());
if (newRestWrapper != null) {
logger.debug("Using REST wrapper from plugin " + plugin.getClass().getName());
if (restWrapper != null) {
throw new IllegalArgumentException("Cannot have more than one plugin implementing a REST wrapper");
}
restWrapper = newRestWrapper;
}
}
restController = new RestController(settings, headers, restWrapper);
}
public Map<String, ActionHandler<?, ?>> getActions() {

View File

@ -148,7 +148,7 @@ public abstract class TransportClient extends AbstractClient {
}
modules.add(b -> b.bind(ThreadPool.class).toInstance(threadPool));
ActionModule actionModule = new ActionModule(false, true, settings, null, settingsModule.getClusterSettings(),
pluginsService.filterPlugins(ActionPlugin.class));
threadPool, pluginsService.filterPlugins(ActionPlugin.class));
modules.add(actionModule);
CircuitBreakerService circuitBreakerService = Node.createCircuitBreakerService(settingsModule.getSettings(),

View File

@ -19,6 +19,8 @@
package org.elasticsearch.http;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.breaker.CircuitBreaker;
@ -112,7 +114,13 @@ public class HttpServer extends AbstractLifecycleComponent implements HttpServer
responseChannel = new ResourceHandlingHttpChannel(channel, circuitBreakerService, contentLength);
restController.dispatchRequest(request, responseChannel, client, threadContext);
} catch (Exception e) {
restController.sendErrorResponse(request, responseChannel, e);
try {
responseChannel.sendResponse(new BytesRestResponse(channel, e));
} catch (Exception inner) {
inner.addSuppressed(e);
logger.error((Supplier<?>) () ->
new ParameterizedMessage("failed to send failure response for uri [{}]", request.uri()), inner);
}
}
}

View File

@ -117,7 +117,6 @@ import org.elasticsearch.plugins.RepositoryPlugin;
import org.elasticsearch.plugins.ScriptPlugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.repositories.RepositoriesModule;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.script.ScriptModule;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchExtRegistry;
@ -346,7 +345,7 @@ public class Node implements Closeable {
SearchModule searchModule = new SearchModule(settings, false, pluginsService.filterPlugins(SearchPlugin.class));
ActionModule actionModule = new ActionModule(DiscoveryNode.isIngestNode(settings), false, settings,
clusterModule.getIndexNameExpressionResolver(), settingsModule.getClusterSettings(),
pluginsService.filterPlugins(ActionPlugin.class));
threadPool, pluginsService.filterPlugins(ActionPlugin.class));
modules.add(actionModule);
modules.add(new GatewayModule());
modules.add(new RepositoriesModule(this.environment, pluginsService.filterPlugins(RepositoryPlugin.class)));
@ -546,7 +545,6 @@ public class Node implements Closeable {
injector.getInstance(RoutingService.class).start();
injector.getInstance(SearchService.class).start();
injector.getInstance(MonitorService.class).start();
injector.getInstance(RestController.class).start();
final ClusterService clusterService = injector.getInstance(ClusterService.class);
@ -670,7 +668,6 @@ public class Node implements Closeable {
injector.getInstance(MonitorService.class).stop();
injector.getInstance(GatewayService.class).stop();
injector.getInstance(SearchService.class).stop();
injector.getInstance(RestController.class).stop();
injector.getInstance(TransportService.class).stop();
pluginLifecycleComponents.forEach(LifecycleComponent::stop);
@ -731,8 +728,6 @@ public class Node implements Closeable {
toClose.add(injector.getInstance(GatewayService.class));
toClose.add(() -> stopWatch.stop().start("search"));
toClose.add(injector.getInstance(SearchService.class));
toClose.add(() -> stopWatch.stop().start("rest"));
toClose.add(injector.getInstance(RestController.class));
toClose.add(() -> stopWatch.stop().start("transport"));
toClose.add(injector.getInstance(TransportService.class));

View File

@ -26,12 +26,14 @@ import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.action.support.TransportAction;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.RestHandler;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.function.UnaryOperator;
/**
* An additional extension point for {@link Plugin}s that extends Elasticsearch's scripting functionality. Implement it like this:
@ -72,6 +74,15 @@ public interface ActionPlugin {
return Collections.emptyList();
}
/**
* Returns a function used to wrap each rest request before handling the request.
*
* Note: Only one installed plugin may implement a rest wrapper.
*/
default UnaryOperator<RestHandler> getRestHandlerWrapper(ThreadContext threadContext) {
return null;
}
final class ActionHandler<Request extends ActionRequest, Response extends ActionResponse> {
private final GenericAction<Request, Response> action;
private final Class<? extends TransportAction<Request, Response>> transportAction;

View File

@ -18,6 +18,11 @@
*/
package org.elasticsearch.plugins;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings;
@ -28,11 +33,6 @@ import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportInterceptor;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Supplier;
/**
* Plugin for extending network and transport related classes
*/
@ -56,6 +56,7 @@ public interface NetworkPlugin {
NetworkService networkService) {
return Collections.emptyMap();
}
/**
* Returns a map of {@link HttpServerTransport} suppliers.
* See {@link org.elasticsearch.common.network.NetworkModule#HTTP_TYPE_SETTING} to configure a specific implementation.

View File

@ -19,27 +19,26 @@
package org.elasticsearch.rest;
import java.io.IOException;
import java.util.Objects;
import java.util.Set;
import java.util.function.UnaryOperator;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.path.PathTrie;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import static org.elasticsearch.rest.RestStatus.BAD_REQUEST;
import static org.elasticsearch.rest.RestStatus.OK;
public class RestController extends AbstractLifecycleComponent {
public class RestController extends AbstractComponent {
private final PathTrie<RestHandler> getHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> postHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> putHandlers = new PathTrie<>(RestUtils.REST_DECODER);
@ -47,43 +46,18 @@ public class RestController extends AbstractLifecycleComponent {
private final PathTrie<RestHandler> headHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final PathTrie<RestHandler> optionsHandlers = new PathTrie<>(RestUtils.REST_DECODER);
private final RestHandlerFilter handlerFilter = new RestHandlerFilter();
private final UnaryOperator<RestHandler> handlerWrapper;
/** Rest headers that are copied to internal requests made during a rest request. */
private final Set<String> headersToCopy;
// non volatile since the assumption is that pre processors are registered on startup
private RestFilter[] filters = new RestFilter[0];
public RestController(Settings settings, Set<String> headersToCopy) {
public RestController(Settings settings, Set<String> headersToCopy, UnaryOperator<RestHandler> handlerWrapper) {
super(settings);
this.headersToCopy = headersToCopy;
}
@Override
protected void doStart() {
}
@Override
protected void doStop() {
}
@Override
protected void doClose() {
for (RestFilter filter : filters) {
filter.close();
if (handlerWrapper == null) {
handlerWrapper = h -> h; // passthrough if no wrapper set
}
}
/**
* Registers a pre processor to be executed before the rest request is actually handled.
*/
public synchronized void registerFilter(RestFilter preProcessor) {
RestFilter[] copy = new RestFilter[filters.length + 1];
System.arraycopy(filters, 0, copy, 0, filters.length);
copy[filters.length] = preProcessor;
Arrays.sort(copy, (o1, o2) -> Integer.compare(o1.order(), o2.order()));
filters = copy;
this.handlerWrapper = handlerWrapper;
}
/**
@ -154,25 +128,6 @@ public class RestController extends AbstractLifecycleComponent {
}
}
/**
* Returns a filter chain (if needed) to execute. If this method returns null, simply execute
* as usual.
*/
@Nullable
public RestFilterChain filterChainOrNull(RestFilter executionFilter) {
if (filters.length == 0) {
return null;
}
return new ControllerFilterChain(executionFilter);
}
/**
* Returns a filter chain with the final filter being the provided filter.
*/
public RestFilterChain filterChain(RestFilter executionFilter) {
return new ControllerFilterChain(executionFilter);
}
/**
* @param request The current request. Must not be null.
* @return true iff the circuit breaker limit must be enforced for processing this request.
@ -193,21 +148,21 @@ public class RestController extends AbstractLifecycleComponent {
threadContext.putHeader(key, httpHeader);
}
}
if (filters.length == 0) {
executeHandler(request, channel, client);
} else {
ControllerFilterChain filterChain = new ControllerFilterChain(handlerFilter);
filterChain.continueProcessing(request, channel, client);
}
}
}
public void sendErrorResponse(RestRequest request, RestChannel channel, Exception e) {
try {
channel.sendResponse(new BytesRestResponse(channel, e));
} catch (Exception inner) {
inner.addSuppressed(e);
logger.error((Supplier<?>) () -> new ParameterizedMessage("failed to send failure response for uri [{}]", request.uri()), inner);
final RestHandler handler = getHandler(request);
if (handler == null) {
if (request.method() == RestRequest.Method.OPTIONS) {
// when we have OPTIONS request, simply send OK by default (with the Access Control Origin header which gets automatically added)
channel.sendResponse(new BytesRestResponse(OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
} else {
final String msg = "No handler found for uri [" + request.uri() + "] and method [" + request.method() + "]";
channel.sendResponse(new BytesRestResponse(BAD_REQUEST, msg));
}
} else {
final RestHandler wrappedHandler = Objects.requireNonNull(handlerWrapper.apply(handler));
wrappedHandler.handleRequest(request, channel, client);
}
}
}
@ -234,21 +189,6 @@ public class RestController extends AbstractLifecycleComponent {
return true;
}
void executeHandler(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
final RestHandler handler = getHandler(request);
if (handler != null) {
handler.handleRequest(request, channel, client);
} else {
if (request.method() == RestRequest.Method.OPTIONS) {
// when we have OPTIONS request, simply send OK by default (with the Access Control Origin header which gets automatically added)
channel.sendResponse(new BytesRestResponse(OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY));
} else {
final String msg = "No handler found for uri [" + request.uri() + "] and method [" + request.method() + "]";
channel.sendResponse(new BytesRestResponse(BAD_REQUEST, msg));
}
}
}
private RestHandler getHandler(RestRequest request) {
String path = getPath(request);
PathTrie<RestHandler> handlers = getHandlersForMethod(request.method());
@ -283,44 +223,4 @@ public class RestController extends AbstractLifecycleComponent {
// my_index/my_type/http%3A%2F%2Fwww.google.com
return request.rawPath();
}
class ControllerFilterChain implements RestFilterChain {
private final RestFilter executionFilter;
private final AtomicInteger index = new AtomicInteger();
ControllerFilterChain(RestFilter executionFilter) {
this.executionFilter = executionFilter;
}
@Override
public void continueProcessing(RestRequest request, RestChannel channel, NodeClient client) {
try {
int loc = index.getAndIncrement();
if (loc > filters.length) {
throw new IllegalStateException("filter continueProcessing was called more than expected");
} else if (loc == filters.length) {
executionFilter.process(request, channel, client, this);
} else {
RestFilter preProcessor = filters[loc];
preProcessor.process(request, channel, client, this);
}
} catch (Exception e) {
try {
channel.sendResponse(new BytesRestResponse(channel, e));
} catch (IOException e1) {
logger.error((Supplier<?>) () -> new ParameterizedMessage("Failed to send failure response for uri [{}]", request.uri()), e1);
}
}
}
}
class RestHandlerFilter extends RestFilter {
@Override
public void process(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception {
executeHandler(request, channel, client);
}
}
}

View File

@ -1,49 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest;
import java.io.Closeable;
import org.elasticsearch.client.node.NodeClient;
/**
* A filter allowing to filter rest operations.
*/
public abstract class RestFilter implements Closeable {
/**
* Optionally, the order of the filter. Execution is done from lowest value to highest.
* It is a good practice to allow to configure this for the relevant filter.
*/
public int order() {
return 0;
}
@Override
public void close() {
// a no op
}
/**
* Process the rest request. Using the channel to send a response, or the filter chain to continue
* processing the request.
*/
public abstract void process(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception;
}

View File

@ -1,34 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest;
import org.elasticsearch.client.node.NodeClient;
/**
* A filter chain allowing to continue and process the rest request.
*/
public interface RestFilterChain {
/**
* Continue processing the request. Should only be called if a response has not been sent
* through the channel.
*/
void continueProcessing(RestRequest request, RestChannel channel, NodeClient client);
}

View File

@ -59,7 +59,7 @@ public class HttpServerTests extends ESTestCase {
inFlightRequestsBreaker = circuitBreakerService.getBreaker(CircuitBreaker.IN_FLIGHT_REQUESTS);
HttpServerTransport httpServerTransport = new TestHttpServerTransport();
RestController restController = new RestController(settings, Collections.emptySet());
RestController restController = new RestController(settings, Collections.emptySet(), null);
restController.registerHandler(RestRequest.Method.GET, "/",
(request, channel, client) -> channel.sendResponse(
new BytesRestResponse(RestStatus.OK, BytesRestResponse.TEXT_CONTENT_TYPE, BytesArray.EMPTY)));

View File

@ -25,6 +25,8 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.UnaryOperator;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.logging.DeprecationLogger;
@ -44,19 +46,13 @@ public class RestControllerTests extends ESTestCase {
public void testApplyRelevantHeaders() throws Exception {
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
Set<String> headers = new HashSet<>(Arrays.asList("header.1", "header.2"));
final RestController restController = new RestController(Settings.EMPTY, headers) {
@Override
boolean checkRequestParameters(RestRequest request, RestChannel channel) {
return true;
}
@Override
void executeHandler(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
final RestController restController = new RestController(Settings.EMPTY, headers, null);
restController.registerHandler(RestRequest.Method.GET, "/",
(RestRequest request, RestChannel channel, NodeClient client) -> {
assertEquals("true", threadContext.getHeader("header.1"));
assertEquals("true", threadContext.getHeader("header.2"));
assertNull(threadContext.getHeader("header.3"));
}
};
});
threadContext.putHeader("header.3", "true");
Map<String, String> restHeaders = new HashMap<>();
restHeaders.put("header.1", "true");
@ -69,7 +65,7 @@ public class RestControllerTests extends ESTestCase {
}
public void testCanTripCircuitBreaker() throws Exception {
RestController controller = new RestController(Settings.EMPTY, Collections.emptySet());
RestController controller = new RestController(Settings.EMPTY, Collections.emptySet(), null);
// trip circuit breaker by default
controller.registerHandler(RestRequest.Method.GET, "/trip", new FakeRestHandler(true));
controller.registerHandler(RestRequest.Method.GET, "/do-not-trip", new FakeRestHandler(false));
@ -119,6 +115,24 @@ public class RestControllerTests extends ESTestCase {
verify(controller).registerAsDeprecatedHandler(deprecatedMethod, deprecatedPath, handler, deprecationMessage, logger);
}
public void testRestHandlerWrapper() throws Exception {
AtomicBoolean handlerCalled = new AtomicBoolean(false);
AtomicBoolean wrapperCalled = new AtomicBoolean(false);
RestHandler handler = (RestRequest request, RestChannel channel, NodeClient client) -> {
handlerCalled.set(true);
};
UnaryOperator<RestHandler> wrapper = h -> {
assertSame(handler, h);
return (RestRequest request, RestChannel channel, NodeClient client) -> wrapperCalled.set(true);
};
final RestController restController = new RestController(Settings.EMPTY, Collections.emptySet(), wrapper);
restController.registerHandler(RestRequest.Method.GET, "/", handler);
final ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
restController.dispatchRequest(new FakeRestRequest.Builder().build(), null, null, threadContext);
assertTrue(wrapperCalled.get());
assertFalse(handlerCalled.get());
}
/**
* Useful for testing with deprecation handler.
*/

View File

@ -1,211 +0,0 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.rest;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestChannel;
import org.elasticsearch.test.rest.FakeRestRequest;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import static org.hamcrest.CoreMatchers.equalTo;
public class RestFilterChainTests extends ESTestCase {
public void testRestFilters() throws Exception {
RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters);
while (orders.size() < numFilters) {
orders.add(randomInt(10));
}
List<RestFilter> filters = new ArrayList<>();
for (Integer order : orders) {
TestFilter testFilter = new TestFilter(order, randomFrom(Operation.values()));
filters.add(testFilter);
restController.registerFilter(testFilter);
}
ArrayList<RestFilter> restFiltersByOrder = new ArrayList<>(filters);
Collections.sort(restFiltersByOrder, new Comparator<RestFilter>() {
@Override
public int compare(RestFilter o1, RestFilter o2) {
return Integer.compare(o1.order(), o2.order());
}
});
List<RestFilter> expectedRestFilters = new ArrayList<>();
for (RestFilter filter : restFiltersByOrder) {
TestFilter testFilter = (TestFilter) filter;
expectedRestFilters.add(testFilter);
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) {
break;
}
}
restController.registerHandler(RestRequest.Method.GET, "/", (request, channel, client) -> {
channel.sendResponse(new TestResponse());
});
FakeRestRequest fakeRestRequest = new FakeRestRequest();
FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, randomBoolean(), 1);
restController.dispatchRequest(fakeRestRequest, fakeRestChannel, null, new ThreadContext(Settings.EMPTY));
assertThat(fakeRestChannel.await(), equalTo(true));
List<TestFilter> testFiltersByLastExecution = new ArrayList<>();
for (RestFilter restFilter : filters) {
testFiltersByLastExecution.add((TestFilter)restFilter);
}
Collections.sort(testFiltersByLastExecution, new Comparator<TestFilter>() {
@Override
public int compare(TestFilter o1, TestFilter o2) {
return Long.compare(o1.executionToken, o2.executionToken);
}
});
ArrayList<TestFilter> finalTestFilters = new ArrayList<>();
for (RestFilter filter : testFiltersByLastExecution) {
TestFilter testFilter = (TestFilter) filter;
finalTestFilters.add(testFilter);
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) {
break;
}
}
assertThat(finalTestFilters.size(), equalTo(expectedRestFilters.size()));
for (int i = 0; i < finalTestFilters.size(); i++) {
TestFilter testFilter = finalTestFilters.get(i);
assertThat(testFilter, equalTo(expectedRestFilters.get(i)));
assertThat(testFilter.runs.get(), equalTo(1));
}
}
public void testTooManyContinueProcessing() throws Exception {
final int additionalContinueCount = randomInt(10);
TestFilter testFilter = new TestFilter(randomInt(), (request, channel, client, filterChain) -> {
for (int i = 0; i <= additionalContinueCount; i++) {
filterChain.continueProcessing(request, channel, null);
}
});
RestController restController = new RestController(Settings.EMPTY, Collections.emptySet());
restController.registerFilter(testFilter);
restController.registerHandler(RestRequest.Method.GET, "/", new RestHandler() {
@Override
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
channel.sendResponse(new TestResponse());
}
});
FakeRestRequest fakeRestRequest = new FakeRestRequest();
FakeRestChannel fakeRestChannel = new FakeRestChannel(fakeRestRequest, randomBoolean(), additionalContinueCount + 1);
restController.dispatchRequest(fakeRestRequest, fakeRestChannel, null, new ThreadContext(Settings.EMPTY));
fakeRestChannel.await();
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(fakeRestChannel.responses().get(), equalTo(1));
assertThat(fakeRestChannel.errors().get(), equalTo(additionalContinueCount));
}
private enum Operation implements Callback {
CONTINUE_PROCESSING {
@Override
public void execute(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception {
filterChain.continueProcessing(request, channel, client);
}
},
CHANNEL_RESPONSE {
@Override
public void execute(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception {
channel.sendResponse(new TestResponse());
}
}
}
private interface Callback {
void execute(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception;
}
private final AtomicInteger counter = new AtomicInteger();
private class TestFilter extends RestFilter {
private final int order;
private final Callback callback;
AtomicInteger runs = new AtomicInteger();
volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list
TestFilter(int order, Callback callback) {
this.order = order;
this.callback = callback;
}
@Override
public void process(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception {
this.runs.incrementAndGet();
this.executionToken = counter.incrementAndGet();
this.callback.execute(request, channel, client, filterChain);
}
@Override
public int order() {
return order;
}
@Override
public String toString() {
return "[order:" + order + ", executionToken:" + executionToken + "]";
}
}
private static class TestResponse extends RestResponse {
@Override
public String contentType() {
return null;
}
@Override
public BytesReference content() {
return null;
}
@Override
public RestStatus status() {
return RestStatus.OK;
}
}
}

View File

@ -43,7 +43,7 @@ public class RestNodesStatsActionTests extends ESTestCase {
@Override
public void setUp() throws Exception {
super.setUp();
action = new RestNodesStatsAction(Settings.EMPTY, new RestController(Settings.EMPTY, Collections.emptySet()));
action = new RestNodesStatsAction(Settings.EMPTY, new RestController(Settings.EMPTY, Collections.emptySet(), null));
}
public void testUnrecognizedMetric() throws IOException {

View File

@ -41,7 +41,7 @@ public class RestIndicesStatsActionTests extends ESTestCase {
@Override
public void setUp() throws Exception {
super.setUp();
action = new RestIndicesStatsAction(Settings.EMPTY, new RestController(Settings.EMPTY, Collections.emptySet()));
action = new RestIndicesStatsAction(Settings.EMPTY, new RestController(Settings.EMPTY, Collections.emptySet(), null));
}
public void testUnrecognizedMetric() throws IOException {

View File

@ -74,7 +74,7 @@ public class RestIndicesActionTests extends ESTestCase {
public void testBuildTable() {
final Settings settings = Settings.EMPTY;
final RestController restController = new RestController(settings, Collections.emptySet());
final RestController restController = new RestController(settings, Collections.emptySet(), null);
final RestIndicesAction action = new RestIndicesAction(settings, restController, new IndexNameExpressionResolver(settings));
// build a (semi-)random table

View File

@ -50,7 +50,7 @@ public class RestRecoveryActionTests extends ESTestCase {
public void testRestRecoveryAction() {
final Settings settings = Settings.EMPTY;
final RestController restController = new RestController(settings, Collections.emptySet());
final RestController restController = new RestController(settings, Collections.emptySet(), null);
final RestRecoveryAction action = new RestRecoveryAction(settings, restController, restController);
final int totalShards = randomIntBetween(1, 32);
final int successfulShards = Math.max(0, totalShards - randomIntBetween(1, 2));