diff --git a/src/main/java/org/elasticsearch/action/support/ActionFilter.java b/src/main/java/org/elasticsearch/action/support/ActionFilter.java index b50d8217c4b..e2bc2f0aaf0 100644 --- a/src/main/java/org/elasticsearch/action/support/ActionFilter.java +++ b/src/main/java/org/elasticsearch/action/support/ActionFilter.java @@ -21,20 +21,67 @@ package org.elasticsearch.action.support; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.common.component.AbstractComponent; +import org.elasticsearch.common.settings.Settings; /** * A filter allowing to filter transport actions */ public interface ActionFilter { - /** - * Filters the actual execution of the request by either sending a response through the {@link ActionListener} - * or continuing the filters execution through the {@link ActionFilterChain} - */ - void process(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain); - /** * The position of the filter in the chain. Execution is done from lowest order to highest. */ int order(); + + /** + * Enables filtering the execution of an action on the request side, either by sending a response through the + * {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain} + */ + void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain); + + /** + * Enables filtering the execution of an action on the response side, either by sending a response through the + * {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain} + */ + void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain); + + /** + * A simple base class for injectable action filters that spares the implementation from handling the + * filter chain. This base class should serve any action filter implementations that doesn't require + * to apply async filtering logic. + */ + public static abstract class Simple extends AbstractComponent implements ActionFilter { + + protected Simple(Settings settings) { + super(settings); + } + + @Override + public final void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) { + if (apply(action, request, listener)) { + chain.proceed(action, request, listener); + } + } + + /** + * Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false} + * if it should be aborted since the filter already handled the request and called the given listener. + */ + protected abstract boolean apply(String action, ActionRequest request, ActionListener listener); + + @Override + public final void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + if (apply(action, response, listener)) { + chain.proceed(action, response, listener); + } + } + + /** + * Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false} + * if it should be aborted since the filter already handled the response by calling the given listener. + */ + protected abstract boolean apply(String action, ActionResponse response, ActionListener listener); + } } diff --git a/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java b/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java index 9dc40cece27..d90d6b5a98e 100644 --- a/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java +++ b/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java @@ -21,6 +21,7 @@ package org.elasticsearch.action.support; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.action.ActionResponse; /** * A filter chain allowing to continue and process the transport action request @@ -28,7 +29,14 @@ import org.elasticsearch.action.ActionRequest; public interface ActionFilterChain { /** - * Continue processing the request. Should only be called if a response has not been sent through the {@link ActionListener} + * Continue processing the request. Should only be called if a response has not been sent through + * the given {@link ActionListener listener} */ - void continueProcessing(final String action, final ActionRequest request, final ActionListener actionListener); + void proceed(final String action, final ActionRequest request, final ActionListener listener); + + /** + * Continue processing the response. Should only be called if a response has not been sent through + * the given {@link ActionListener listener} + */ + void proceed(final String action, final ActionResponse response, final ActionListener listener); } diff --git a/src/main/java/org/elasticsearch/action/support/TransportAction.java b/src/main/java/org/elasticsearch/action/support/TransportAction.java index 1e3b8a8eb53..9e849fd8b40 100644 --- a/src/main/java/org/elasticsearch/action/support/TransportAction.java +++ b/src/main/java/org/elasticsearch/action/support/TransportAction.java @@ -78,8 +78,8 @@ public abstract class TransportAction(this, logger); + requestFilterChain.proceed(actionName, request, listener); } } @@ -146,28 +146,95 @@ public abstract class TransportAction implements ActionFilterChain { + private final TransportAction action; private final AtomicInteger index = new AtomicInteger(); + private final ESLogger logger; - @SuppressWarnings("unchecked") - @Override - public void continueProcessing(String action, ActionRequest actionRequest, ActionListener actionListener) { + private RequestFilterChain(TransportAction action, ESLogger logger) { + this.action = action; + this.logger = logger; + } + + @Override @SuppressWarnings("unchecked") + public void proceed(String actionName, ActionRequest request, ActionListener listener) { int i = index.getAndIncrement(); try { - if (i < filters.length) { - filters[i].process(action, actionRequest, actionListener, this); - } else if (i == filters.length) { - ActionListener listener = (ActionListener) actionListener; - Request request = (Request) actionRequest; - doExecute(request, listener); + if (i < this.action.filters.length) { + this.action.filters[i].apply(actionName, request, listener, this); + } else if (i == this.action.filters.length) { + this.action.doExecute((Request) request, new FilteredActionListener(actionName, listener, new ResponseFilterChain(this.action.filters, logger))); } else { - actionListener.onFailure(new IllegalStateException("continueProcessing was called too many times")); + listener.onFailure(new IllegalStateException("proceed was called too many times")); } } catch(Throwable t) { logger.trace("Error during transport action execution.", t); - actionListener.onFailure(t); + listener.onFailure(t); + } + } + + @Override + public void proceed(String action, ActionResponse response, ActionListener listener) { + assert false : "request filter chain should never be called on the response side"; + } + } + + private static class ResponseFilterChain implements ActionFilterChain { + + private final ActionFilter[] filters; + private final AtomicInteger index; + private final ESLogger logger; + + private ResponseFilterChain(ActionFilter[] filters, ESLogger logger) { + this.filters = filters; + this.index = new AtomicInteger(filters.length); + this.logger = logger; + } + + @Override + public void proceed(String action, ActionRequest request, ActionListener listener) { + assert false : "response filter chain should never be called on the request side"; + } + + @Override @SuppressWarnings("unchecked") + public void proceed(String action, ActionResponse response, ActionListener listener) { + int i = index.decrementAndGet(); + try { + if (i >= 0) { + filters[i].apply(action, response, listener, this); + } else if (i == -1) { + listener.onResponse(response); + } else { + listener.onFailure(new IllegalStateException("proceed was called too many times")); + } + } catch (Throwable t) { + logger.trace("Error during transport action execution.", t); + listener.onFailure(t); } } } + + private static class FilteredActionListener implements ActionListener { + + private final String actionName; + private final ActionListener listener; + private final ResponseFilterChain chain; + + private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) { + this.actionName = actionName; + this.listener = listener; + this.chain = chain; + } + + @Override + public void onResponse(Response response) { + chain.proceed(actionName, response, listener); + } + + @Override + public void onFailure(Throwable e) { + listener.onFailure(e); + } + } } diff --git a/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java b/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java index 433165afb06..1eccf3521f0 100644 --- a/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java +++ b/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.test.ElasticsearchTestCase; +import org.junit.Before; import org.junit.Test; import java.util.*; @@ -40,8 +41,15 @@ import static org.hamcrest.CoreMatchers.*; public class TransportActionFilterChainTests extends ElasticsearchTestCase { + private AtomicInteger counter; + + @Before + public void init() throws Exception { + counter = new AtomicInteger(); + } + @Test - public void testActionFilters() throws ExecutionException, InterruptedException { + public void testActionFiltersRequest() throws ExecutionException, InterruptedException { int numFilters = randomInt(10); Set orders = new HashSet<>(numFilters); @@ -51,7 +59,7 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { Set filters = new HashSet<>(); for (Integer order : orders) { - filters.add(new TestFilter(order, randomFrom(Operation.values()))); + filters.add(new RequestTestFilter(order, randomFrom(RequestOperation.values()))); } String actionName = randomAsciiOfLength(randomInt(30)); @@ -74,12 +82,12 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { List expectedActionFilters = Lists.newArrayList(); boolean errorExpected = false; for (ActionFilter filter : actionFiltersByOrder) { - TestFilter testFilter = (TestFilter) filter; + RequestTestFilter testFilter = (RequestTestFilter) filter; expectedActionFilters.add(testFilter); - if (testFilter.callback == Operation.LISTENER_FAILURE) { + if (testFilter.callback == RequestOperation.LISTENER_FAILURE) { errorExpected = true; } - if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { + if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) { break; } } @@ -93,29 +101,29 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true)); } - List testFiltersByLastExecution = Lists.newArrayList(); + List testFiltersByLastExecution = Lists.newArrayList(); for (ActionFilter actionFilter : actionFilters.filters()) { - testFiltersByLastExecution.add((TestFilter) actionFilter); + testFiltersByLastExecution.add((RequestTestFilter) actionFilter); } - Collections.sort(testFiltersByLastExecution, new Comparator() { + Collections.sort(testFiltersByLastExecution, new Comparator() { @Override - public int compare(TestFilter o1, TestFilter o2) { + public int compare(RequestTestFilter o1, RequestTestFilter o2) { return Integer.compare(o1.executionToken, o2.executionToken); } }); - ArrayList finalTestFilters = Lists.newArrayList(); + ArrayList finalTestFilters = Lists.newArrayList(); for (ActionFilter filter : testFiltersByLastExecution) { - TestFilter testFilter = (TestFilter) filter; + RequestTestFilter testFilter = (RequestTestFilter) filter; finalTestFilters.add(testFilter); - if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { + if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) { break; } } assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size())); for (int i = 0; i < finalTestFilters.size(); i++) { - TestFilter testFilter = finalTestFilters.get(i); + RequestTestFilter testFilter = finalTestFilters.get(i); assertThat(testFilter, equalTo(expectedActionFilters.get(i))); assertThat(testFilter.runs.get(), equalTo(1)); assertThat(testFilter.lastActionName, equalTo(actionName)); @@ -123,15 +131,97 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { } @Test - public void testTooManyContinueProcessing() throws ExecutionException, InterruptedException { + public void testActionFiltersResponse() throws ExecutionException, InterruptedException { + + int numFilters = randomInt(10); + Set orders = new HashSet<>(numFilters); + while (orders.size() < numFilters) { + orders.add(randomInt(10)); + } + + Set filters = new HashSet<>(); + for (Integer order : orders) { + filters.add(new ResponseTestFilter(order, randomFrom(ResponseOperation.values()))); + } + + String actionName = randomAsciiOfLength(randomInt(30)); + ActionFilters actionFilters = new ActionFilters(filters); + TransportAction transportAction = new TransportAction(ImmutableSettings.EMPTY, actionName, null, actionFilters) { + @Override + protected void doExecute(TestRequest request, ActionListener listener) { + listener.onResponse(new TestResponse()); + } + }; + + ArrayList actionFiltersByOrder = Lists.newArrayList(filters); + Collections.sort(actionFiltersByOrder, new Comparator() { + @Override + public int compare(ActionFilter o1, ActionFilter o2) { + return Integer.compare(o2.order(), o1.order()); + } + }); + + List expectedActionFilters = Lists.newArrayList(); + boolean errorExpected = false; + for (ActionFilter filter : actionFiltersByOrder) { + ResponseTestFilter testFilter = (ResponseTestFilter) filter; + expectedActionFilters.add(testFilter); + if (testFilter.callback == ResponseOperation.LISTENER_FAILURE) { + errorExpected = true; + } + if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) { + break; + } + } + + PlainListenableActionFuture future = new PlainListenableActionFuture<>(false, null); + transportAction.execute(new TestRequest(), future); + try { + assertThat(future.get(), notNullValue()); + assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false)); + } catch(Throwable t) { + assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true)); + } + + List testFiltersByLastExecution = Lists.newArrayList(); + for (ActionFilter actionFilter : actionFilters.filters()) { + testFiltersByLastExecution.add((ResponseTestFilter) actionFilter); + } + Collections.sort(testFiltersByLastExecution, new Comparator() { + @Override + public int compare(ResponseTestFilter o1, ResponseTestFilter o2) { + return Integer.compare(o1.executionToken, o2.executionToken); + } + }); + + ArrayList finalTestFilters = Lists.newArrayList(); + for (ActionFilter filter : testFiltersByLastExecution) { + ResponseTestFilter testFilter = (ResponseTestFilter) filter; + finalTestFilters.add(testFilter); + if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) { + break; + } + } + + assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size())); + for (int i = 0; i < finalTestFilters.size(); i++) { + ResponseTestFilter testFilter = finalTestFilters.get(i); + assertThat(testFilter, equalTo(expectedActionFilters.get(i))); + assertThat(testFilter.runs.get(), equalTo(1)); + assertThat(testFilter.lastActionName, equalTo(actionName)); + } + } + + @Test + public void testTooManyContinueProcessingRequest() throws ExecutionException, InterruptedException { final int additionalContinueCount = randomInt(10); - TestFilter testFilter = new TestFilter(randomInt(), new Callback() { + RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() { @Override public void execute(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) { for (int i = 0; i <= additionalContinueCount; i++) { - actionFilterChain.continueProcessing(action, actionRequest, actionListener); + actionFilterChain.proceed(action, actionRequest, actionListener); } } }); @@ -180,24 +270,84 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { } } - private final AtomicInteger counter = new AtomicInteger(); + @Test + public void testTooManyContinueProcessingResponse() throws ExecutionException, InterruptedException { - private class TestFilter implements ActionFilter { + final int additionalContinueCount = randomInt(10); + + ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() { + @Override + public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + for (int i = 0; i <= additionalContinueCount; i++) { + chain.proceed(action, response, listener); + } + } + }); + + Set filters = new HashSet<>(); + filters.add(testFilter); + + String actionName = randomAsciiOfLength(randomInt(30)); + ActionFilters actionFilters = new ActionFilters(filters); + TransportAction transportAction = new TransportAction(ImmutableSettings.EMPTY, actionName, null, actionFilters) { + @Override + protected void doExecute(TestRequest request, ActionListener listener) { + listener.onResponse(new TestResponse()); + } + }; + + final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1); + final AtomicInteger responses = new AtomicInteger(); + final List failures = new CopyOnWriteArrayList<>(); + + transportAction.execute(new TestRequest(), new ActionListener() { + @Override + public void onResponse(TestResponse testResponse) { + responses.incrementAndGet(); + latch.countDown(); + } + + @Override + public void onFailure(Throwable e) { + failures.add(e); + latch.countDown(); + } + }); + + if (!latch.await(10, TimeUnit.SECONDS)) { + fail("timeout waiting for the filter to notify the listener as many times as expected"); + } + + assertThat(testFilter.runs.get(), equalTo(1)); + assertThat(testFilter.lastActionName, equalTo(actionName)); + + assertThat(responses.get(), equalTo(1)); + assertThat(failures.size(), equalTo(additionalContinueCount)); + for (Throwable failure : failures) { + assertThat(failure, instanceOf(IllegalStateException.class)); + } + } + + private class RequestTestFilter implements ActionFilter { + private final RequestCallback callback; private final int order; - private final Callback callback; - AtomicInteger runs = new AtomicInteger(); volatile String lastActionName; volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list - TestFilter(int order, Callback callback) { + RequestTestFilter(int order, RequestCallback callback) { this.order = order; this.callback = callback; } + @Override + public int order() { + return order; + } + @SuppressWarnings("unchecked") @Override - public void process(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { + public void apply(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { this.runs.incrementAndGet(); this.lastActionName = action; this.executionToken = counter.incrementAndGet(); @@ -205,16 +355,47 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { } @Override - public int order() { - return order; + public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + chain.proceed(action, response, listener); } } - private static enum Operation implements Callback { + private class ResponseTestFilter implements ActionFilter { + private final ResponseCallback callback; + private final int order; + AtomicInteger runs = new AtomicInteger(); + volatile String lastActionName; + volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list + + ResponseTestFilter(int order, ResponseCallback callback) { + this.order = order; + this.callback = callback; + } + + @Override + public int order() { + return order; + } + + @Override + public void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) { + chain.proceed(action, request, listener); + } + + @Override + public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + this.runs.incrementAndGet(); + this.lastActionName = action; + this.executionToken = counter.incrementAndGet(); + this.callback.execute(action, response, listener, chain); + } + } + + private static enum RequestOperation implements RequestCallback { CONTINUE_PROCESSING { @Override public void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { - actionFilterChain.continueProcessing(action, actionRequest, actionListener); + actionFilterChain.proceed(action, actionRequest, actionListener); } }, LISTENER_RESPONSE { @@ -232,10 +413,36 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase { } } - private static interface Callback { + private static enum ResponseOperation implements ResponseCallback { + CONTINUE_PROCESSING { + @Override + public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + chain.proceed(action, response, listener); + } + }, + LISTENER_RESPONSE { + @Override + @SuppressWarnings("unchecked") + public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + listener.onResponse(new TestResponse()); + } + }, + LISTENER_FAILURE { + @Override + public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + listener.onFailure(new ElasticsearchTimeoutException("")); + } + } + } + + private static interface RequestCallback { void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain); } + private static interface ResponseCallback { + void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain); + } + private static class TestRequest extends ActionRequest { @Override public ActionRequestValidationException validate() {