Extended ActionFilter to also enable filtering the response side

Enables filtering the actions on both sides - request and response. Also added a base class for filter implementations (cleans up filters that only need to filter one side)

Also refactored the filter & filter chain methods to more intuitive names
This commit is contained in:
uboness 2014-08-26 10:39:45 -07:00
parent dd54025b17
commit 333a39cf30
4 changed files with 378 additions and 49 deletions

View File

@ -21,20 +21,67 @@ package org.elasticsearch.action.support;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.settings.Settings;
/** /**
* A filter allowing to filter transport actions * A filter allowing to filter transport actions
*/ */
public interface ActionFilter { public interface ActionFilter {
/**
* Filters the actual execution of the request by either sending a response through the {@link ActionListener}
* or continuing the filters execution through the {@link ActionFilterChain}
*/
void process(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain);
/** /**
* The position of the filter in the chain. Execution is done from lowest order to highest. * The position of the filter in the chain. Execution is done from lowest order to highest.
*/ */
int order(); int order();
/**
* Enables filtering the execution of an action on the request side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain);
/**
* Enables filtering the execution of an action on the response side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
/**
* A simple base class for injectable action filters that spares the implementation from handling the
* filter chain. This base class should serve any action filter implementations that doesn't require
* to apply async filtering logic.
*/
public static abstract class Simple extends AbstractComponent implements ActionFilter {
protected Simple(Settings settings) {
super(settings);
}
@Override
public final void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
if (apply(action, request, listener)) {
chain.proceed(action, request, listener);
}
}
/**
* Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false}
* if it should be aborted since the filter already handled the request and called the given listener.
*/
protected abstract boolean apply(String action, ActionRequest request, ActionListener listener);
@Override
public final void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
if (apply(action, response, listener)) {
chain.proceed(action, response, listener);
}
}
/**
* Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false}
* if it should be aborted since the filter already handled the response by calling the given listener.
*/
protected abstract boolean apply(String action, ActionResponse response, ActionListener listener);
}
} }

View File

@ -21,6 +21,7 @@ package org.elasticsearch.action.support;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
/** /**
* A filter chain allowing to continue and process the transport action request * A filter chain allowing to continue and process the transport action request
@ -28,7 +29,14 @@ import org.elasticsearch.action.ActionRequest;
public interface ActionFilterChain { public interface ActionFilterChain {
/** /**
* Continue processing the request. Should only be called if a response has not been sent through the {@link ActionListener} * Continue processing the request. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/ */
void continueProcessing(final String action, final ActionRequest request, final ActionListener actionListener); void proceed(final String action, final ActionRequest request, final ActionListener listener);
/**
* Continue processing the response. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/
void proceed(final String action, final ActionResponse response, final ActionListener listener);
} }

View File

@ -78,8 +78,8 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
listener.onFailure(t); listener.onFailure(t);
} }
} else { } else {
ActionFilterChain actionFilterChain = new TransportActionFilterChain(); RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger);
actionFilterChain.continueProcessing(actionName, request, listener); requestFilterChain.proceed(actionName, request, listener);
} }
} }
@ -146,28 +146,95 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
} }
} }
private class TransportActionFilterChain implements ActionFilterChain { private static class RequestFilterChain<Request extends ActionRequest, Response extends ActionResponse> implements ActionFilterChain {
private final TransportAction<Request, Response> action;
private final AtomicInteger index = new AtomicInteger(); private final AtomicInteger index = new AtomicInteger();
private final ESLogger logger;
@SuppressWarnings("unchecked") private RequestFilterChain(TransportAction<Request, Response> action, ESLogger logger) {
@Override this.action = action;
public void continueProcessing(String action, ActionRequest actionRequest, ActionListener actionListener) { this.logger = logger;
}
@Override @SuppressWarnings("unchecked")
public void proceed(String actionName, ActionRequest request, ActionListener listener) {
int i = index.getAndIncrement(); int i = index.getAndIncrement();
try { try {
if (i < filters.length) { if (i < this.action.filters.length) {
filters[i].process(action, actionRequest, actionListener, this); this.action.filters[i].apply(actionName, request, listener, this);
} else if (i == filters.length) { } else if (i == this.action.filters.length) {
ActionListener<Response> listener = (ActionListener<Response>) actionListener; this.action.doExecute((Request) request, new FilteredActionListener<Response>(actionName, listener, new ResponseFilterChain(this.action.filters, logger)));
Request request = (Request) actionRequest;
doExecute(request, listener);
} else { } else {
actionListener.onFailure(new IllegalStateException("continueProcessing was called too many times")); listener.onFailure(new IllegalStateException("proceed was called too many times"));
} }
} catch(Throwable t) { } catch(Throwable t) {
logger.trace("Error during transport action execution.", t); logger.trace("Error during transport action execution.", t);
actionListener.onFailure(t); listener.onFailure(t);
}
}
@Override
public void proceed(String action, ActionResponse response, ActionListener listener) {
assert false : "request filter chain should never be called on the response side";
}
}
private static class ResponseFilterChain implements ActionFilterChain {
private final ActionFilter[] filters;
private final AtomicInteger index;
private final ESLogger logger;
private ResponseFilterChain(ActionFilter[] filters, ESLogger logger) {
this.filters = filters;
this.index = new AtomicInteger(filters.length);
this.logger = logger;
}
@Override
public void proceed(String action, ActionRequest request, ActionListener listener) {
assert false : "response filter chain should never be called on the request side";
}
@Override @SuppressWarnings("unchecked")
public void proceed(String action, ActionResponse response, ActionListener listener) {
int i = index.decrementAndGet();
try {
if (i >= 0) {
filters[i].apply(action, response, listener, this);
} else if (i == -1) {
listener.onResponse(response);
} else {
listener.onFailure(new IllegalStateException("proceed was called too many times"));
}
} catch (Throwable t) {
logger.trace("Error during transport action execution.", t);
listener.onFailure(t);
} }
} }
} }
private static class FilteredActionListener<Response extends ActionResponse> implements ActionListener<Response> {
private final String actionName;
private final ActionListener listener;
private final ResponseFilterChain chain;
private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) {
this.actionName = actionName;
this.listener = listener;
this.chain = chain;
}
@Override
public void onResponse(Response response) {
chain.proceed(actionName, response, listener);
}
@Override
public void onFailure(Throwable e) {
listener.onFailure(e);
}
}
} }

View File

@ -27,6 +27,7 @@ import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.settings.ImmutableSettings; import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.test.ElasticsearchTestCase; import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.*; import java.util.*;
@ -40,8 +41,15 @@ import static org.hamcrest.CoreMatchers.*;
public class TransportActionFilterChainTests extends ElasticsearchTestCase { public class TransportActionFilterChainTests extends ElasticsearchTestCase {
private AtomicInteger counter;
@Before
public void init() throws Exception {
counter = new AtomicInteger();
}
@Test @Test
public void testActionFilters() throws ExecutionException, InterruptedException { public void testActionFiltersRequest() throws ExecutionException, InterruptedException {
int numFilters = randomInt(10); int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters); Set<Integer> orders = new HashSet<>(numFilters);
@ -51,7 +59,7 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
Set<ActionFilter> filters = new HashSet<>(); Set<ActionFilter> filters = new HashSet<>();
for (Integer order : orders) { for (Integer order : orders) {
filters.add(new TestFilter(order, randomFrom(Operation.values()))); filters.add(new RequestTestFilter(order, randomFrom(RequestOperation.values())));
} }
String actionName = randomAsciiOfLength(randomInt(30)); String actionName = randomAsciiOfLength(randomInt(30));
@ -74,12 +82,12 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
List<ActionFilter> expectedActionFilters = Lists.newArrayList(); List<ActionFilter> expectedActionFilters = Lists.newArrayList();
boolean errorExpected = false; boolean errorExpected = false;
for (ActionFilter filter : actionFiltersByOrder) { for (ActionFilter filter : actionFiltersByOrder) {
TestFilter testFilter = (TestFilter) filter; RequestTestFilter testFilter = (RequestTestFilter) filter;
expectedActionFilters.add(testFilter); expectedActionFilters.add(testFilter);
if (testFilter.callback == Operation.LISTENER_FAILURE) { if (testFilter.callback == RequestOperation.LISTENER_FAILURE) {
errorExpected = true; errorExpected = true;
} }
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) {
break; break;
} }
} }
@ -93,29 +101,29 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true)); assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true));
} }
List<TestFilter> testFiltersByLastExecution = Lists.newArrayList(); List<RequestTestFilter> testFiltersByLastExecution = Lists.newArrayList();
for (ActionFilter actionFilter : actionFilters.filters()) { for (ActionFilter actionFilter : actionFilters.filters()) {
testFiltersByLastExecution.add((TestFilter) actionFilter); testFiltersByLastExecution.add((RequestTestFilter) actionFilter);
} }
Collections.sort(testFiltersByLastExecution, new Comparator<TestFilter>() { Collections.sort(testFiltersByLastExecution, new Comparator<RequestTestFilter>() {
@Override @Override
public int compare(TestFilter o1, TestFilter o2) { public int compare(RequestTestFilter o1, RequestTestFilter o2) {
return Integer.compare(o1.executionToken, o2.executionToken); return Integer.compare(o1.executionToken, o2.executionToken);
} }
}); });
ArrayList<TestFilter> finalTestFilters = Lists.newArrayList(); ArrayList<RequestTestFilter> finalTestFilters = Lists.newArrayList();
for (ActionFilter filter : testFiltersByLastExecution) { for (ActionFilter filter : testFiltersByLastExecution) {
TestFilter testFilter = (TestFilter) filter; RequestTestFilter testFilter = (RequestTestFilter) filter;
finalTestFilters.add(testFilter); finalTestFilters.add(testFilter);
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) { if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) {
break; break;
} }
} }
assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size())); assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
for (int i = 0; i < finalTestFilters.size(); i++) { for (int i = 0; i < finalTestFilters.size(); i++) {
TestFilter testFilter = finalTestFilters.get(i); RequestTestFilter testFilter = finalTestFilters.get(i);
assertThat(testFilter, equalTo(expectedActionFilters.get(i))); assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
assertThat(testFilter.runs.get(), equalTo(1)); assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName)); assertThat(testFilter.lastActionName, equalTo(actionName));
@ -123,15 +131,97 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
} }
@Test @Test
public void testTooManyContinueProcessing() throws ExecutionException, InterruptedException { public void testActionFiltersResponse() throws ExecutionException, InterruptedException {
int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters);
while (orders.size() < numFilters) {
orders.add(randomInt(10));
}
Set<ActionFilter> filters = new HashSet<>();
for (Integer order : orders) {
filters.add(new ResponseTestFilter(order, randomFrom(ResponseOperation.values())));
}
String actionName = randomAsciiOfLength(randomInt(30));
ActionFilters actionFilters = new ActionFilters(filters);
TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(ImmutableSettings.EMPTY, actionName, null, actionFilters) {
@Override
protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) {
listener.onResponse(new TestResponse());
}
};
ArrayList<ActionFilter> actionFiltersByOrder = Lists.newArrayList(filters);
Collections.sort(actionFiltersByOrder, new Comparator<ActionFilter>() {
@Override
public int compare(ActionFilter o1, ActionFilter o2) {
return Integer.compare(o2.order(), o1.order());
}
});
List<ActionFilter> expectedActionFilters = Lists.newArrayList();
boolean errorExpected = false;
for (ActionFilter filter : actionFiltersByOrder) {
ResponseTestFilter testFilter = (ResponseTestFilter) filter;
expectedActionFilters.add(testFilter);
if (testFilter.callback == ResponseOperation.LISTENER_FAILURE) {
errorExpected = true;
}
if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
break;
}
}
PlainListenableActionFuture<TestResponse> future = new PlainListenableActionFuture<>(false, null);
transportAction.execute(new TestRequest(), future);
try {
assertThat(future.get(), notNullValue());
assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false));
} catch(Throwable t) {
assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true));
}
List<ResponseTestFilter> testFiltersByLastExecution = Lists.newArrayList();
for (ActionFilter actionFilter : actionFilters.filters()) {
testFiltersByLastExecution.add((ResponseTestFilter) actionFilter);
}
Collections.sort(testFiltersByLastExecution, new Comparator<ResponseTestFilter>() {
@Override
public int compare(ResponseTestFilter o1, ResponseTestFilter o2) {
return Integer.compare(o1.executionToken, o2.executionToken);
}
});
ArrayList<ResponseTestFilter> finalTestFilters = Lists.newArrayList();
for (ActionFilter filter : testFiltersByLastExecution) {
ResponseTestFilter testFilter = (ResponseTestFilter) filter;
finalTestFilters.add(testFilter);
if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
break;
}
}
assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
for (int i = 0; i < finalTestFilters.size(); i++) {
ResponseTestFilter testFilter = finalTestFilters.get(i);
assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));
}
}
@Test
public void testTooManyContinueProcessingRequest() throws ExecutionException, InterruptedException {
final int additionalContinueCount = randomInt(10); final int additionalContinueCount = randomInt(10);
TestFilter testFilter = new TestFilter(randomInt(), new Callback() { RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
@Override @Override
public void execute(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) { public void execute(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) {
for (int i = 0; i <= additionalContinueCount; i++) { for (int i = 0; i <= additionalContinueCount; i++) {
actionFilterChain.continueProcessing(action, actionRequest, actionListener); actionFilterChain.proceed(action, actionRequest, actionListener);
} }
} }
}); });
@ -180,24 +270,84 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
} }
} }
private final AtomicInteger counter = new AtomicInteger(); @Test
public void testTooManyContinueProcessingResponse() throws ExecutionException, InterruptedException {
private class TestFilter implements ActionFilter { final int additionalContinueCount = randomInt(10);
ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
for (int i = 0; i <= additionalContinueCount; i++) {
chain.proceed(action, response, listener);
}
}
});
Set<ActionFilter> filters = new HashSet<>();
filters.add(testFilter);
String actionName = randomAsciiOfLength(randomInt(30));
ActionFilters actionFilters = new ActionFilters(filters);
TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(ImmutableSettings.EMPTY, actionName, null, actionFilters) {
@Override
protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) {
listener.onResponse(new TestResponse());
}
};
final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1);
final AtomicInteger responses = new AtomicInteger();
final List<Throwable> failures = new CopyOnWriteArrayList<>();
transportAction.execute(new TestRequest(), new ActionListener<TestResponse>() {
@Override
public void onResponse(TestResponse testResponse) {
responses.incrementAndGet();
latch.countDown();
}
@Override
public void onFailure(Throwable e) {
failures.add(e);
latch.countDown();
}
});
if (!latch.await(10, TimeUnit.SECONDS)) {
fail("timeout waiting for the filter to notify the listener as many times as expected");
}
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));
assertThat(responses.get(), equalTo(1));
assertThat(failures.size(), equalTo(additionalContinueCount));
for (Throwable failure : failures) {
assertThat(failure, instanceOf(IllegalStateException.class));
}
}
private class RequestTestFilter implements ActionFilter {
private final RequestCallback callback;
private final int order; private final int order;
private final Callback callback;
AtomicInteger runs = new AtomicInteger(); AtomicInteger runs = new AtomicInteger();
volatile String lastActionName; volatile String lastActionName;
volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list
TestFilter(int order, Callback callback) { RequestTestFilter(int order, RequestCallback callback) {
this.order = order; this.order = order;
this.callback = callback; this.callback = callback;
} }
@Override
public int order() {
return order;
}
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Override @Override
public void process(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { public void apply(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
this.runs.incrementAndGet(); this.runs.incrementAndGet();
this.lastActionName = action; this.lastActionName = action;
this.executionToken = counter.incrementAndGet(); this.executionToken = counter.incrementAndGet();
@ -205,16 +355,47 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
} }
@Override @Override
public int order() { public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
return order; chain.proceed(action, response, listener);
} }
} }
private static enum Operation implements Callback { private class ResponseTestFilter implements ActionFilter {
private final ResponseCallback callback;
private final int order;
AtomicInteger runs = new AtomicInteger();
volatile String lastActionName;
volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list
ResponseTestFilter(int order, ResponseCallback callback) {
this.order = order;
this.callback = callback;
}
@Override
public int order() {
return order;
}
@Override
public void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
chain.proceed(action, request, listener);
}
@Override
public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
this.runs.incrementAndGet();
this.lastActionName = action;
this.executionToken = counter.incrementAndGet();
this.callback.execute(action, response, listener, chain);
}
}
private static enum RequestOperation implements RequestCallback {
CONTINUE_PROCESSING { CONTINUE_PROCESSING {
@Override @Override
public void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { public void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
actionFilterChain.continueProcessing(action, actionRequest, actionListener); actionFilterChain.proceed(action, actionRequest, actionListener);
} }
}, },
LISTENER_RESPONSE { LISTENER_RESPONSE {
@ -232,10 +413,36 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
} }
} }
private static interface Callback { private static enum ResponseOperation implements ResponseCallback {
CONTINUE_PROCESSING {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
chain.proceed(action, response, listener);
}
},
LISTENER_RESPONSE {
@Override
@SuppressWarnings("unchecked")
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
listener.onResponse(new TestResponse());
}
},
LISTENER_FAILURE {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
listener.onFailure(new ElasticsearchTimeoutException(""));
}
}
}
private static interface RequestCallback {
void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain); void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain);
} }
private static interface ResponseCallback {
void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
}
private static class TestRequest extends ActionRequest { private static class TestRequest extends ActionRequest {
@Override @Override
public ActionRequestValidationException validate() { public ActionRequestValidationException validate() {