Extended ActionFilter to also enable filtering the response side

Enables filtering the actions on both sides - request and response. Also added a base class for filter implementations (cleans up filters that only need to filter one side)

Also refactored the filter & filter chain methods to more intuitive names
This commit is contained in:
uboness 2014-08-26 10:39:45 -07:00
parent dd54025b17
commit 333a39cf30
4 changed files with 378 additions and 49 deletions

View File

@ -21,20 +21,67 @@ package org.elasticsearch.action.support;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.settings.Settings;
/**
* A filter allowing to filter transport actions
*/
public interface ActionFilter {
/**
* Filters the actual execution of the request by either sending a response through the {@link ActionListener}
* or continuing the filters execution through the {@link ActionFilterChain}
*/
void process(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain);
/**
* The position of the filter in the chain. Execution is done from lowest order to highest.
*/
int order();
/**
* Enables filtering the execution of an action on the request side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain);
/**
* Enables filtering the execution of an action on the response side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
/**
* A simple base class for injectable action filters that spares the implementation from handling the
* filter chain. This base class should serve any action filter implementations that doesn't require
* to apply async filtering logic.
*/
public static abstract class Simple extends AbstractComponent implements ActionFilter {
protected Simple(Settings settings) {
super(settings);
}
@Override
public final void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
if (apply(action, request, listener)) {
chain.proceed(action, request, listener);
}
}
/**
* Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false}
* if it should be aborted since the filter already handled the request and called the given listener.
*/
protected abstract boolean apply(String action, ActionRequest request, ActionListener listener);
@Override
public final void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
if (apply(action, response, listener)) {
chain.proceed(action, response, listener);
}
}
/**
* Applies this filter and returns {@code true} if the execution chain should proceed, or {@code false}
* if it should be aborted since the filter already handled the response by calling the given listener.
*/
protected abstract boolean apply(String action, ActionResponse response, ActionListener listener);
}
}

View File

@ -21,6 +21,7 @@ package org.elasticsearch.action.support;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
/**
* A filter chain allowing to continue and process the transport action request
@ -28,7 +29,14 @@ import org.elasticsearch.action.ActionRequest;
public interface ActionFilterChain {
/**
* Continue processing the request. Should only be called if a response has not been sent through the {@link ActionListener}
* Continue processing the request. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/
void continueProcessing(final String action, final ActionRequest request, final ActionListener actionListener);
void proceed(final String action, final ActionRequest request, final ActionListener listener);
/**
* Continue processing the response. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/
void proceed(final String action, final ActionResponse response, final ActionListener listener);
}

View File

@ -78,8 +78,8 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
listener.onFailure(t);
}
} else {
ActionFilterChain actionFilterChain = new TransportActionFilterChain();
actionFilterChain.continueProcessing(actionName, request, listener);
RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger);
requestFilterChain.proceed(actionName, request, listener);
}
}
@ -146,28 +146,95 @@ public abstract class TransportAction<Request extends ActionRequest, Response ex
}
}
private class TransportActionFilterChain implements ActionFilterChain {
private static class RequestFilterChain<Request extends ActionRequest, Response extends ActionResponse> implements ActionFilterChain {
private final TransportAction<Request, Response> action;
private final AtomicInteger index = new AtomicInteger();
private final ESLogger logger;
@SuppressWarnings("unchecked")
@Override
public void continueProcessing(String action, ActionRequest actionRequest, ActionListener actionListener) {
private RequestFilterChain(TransportAction<Request, Response> action, ESLogger logger) {
this.action = action;
this.logger = logger;
}
@Override @SuppressWarnings("unchecked")
public void proceed(String actionName, ActionRequest request, ActionListener listener) {
int i = index.getAndIncrement();
try {
if (i < filters.length) {
filters[i].process(action, actionRequest, actionListener, this);
} else if (i == filters.length) {
ActionListener<Response> listener = (ActionListener<Response>) actionListener;
Request request = (Request) actionRequest;
doExecute(request, listener);
if (i < this.action.filters.length) {
this.action.filters[i].apply(actionName, request, listener, this);
} else if (i == this.action.filters.length) {
this.action.doExecute((Request) request, new FilteredActionListener<Response>(actionName, listener, new ResponseFilterChain(this.action.filters, logger)));
} else {
actionListener.onFailure(new IllegalStateException("continueProcessing was called too many times"));
listener.onFailure(new IllegalStateException("proceed was called too many times"));
}
} catch(Throwable t) {
logger.trace("Error during transport action execution.", t);
actionListener.onFailure(t);
listener.onFailure(t);
}
}
@Override
public void proceed(String action, ActionResponse response, ActionListener listener) {
assert false : "request filter chain should never be called on the response side";
}
}
private static class ResponseFilterChain implements ActionFilterChain {
private final ActionFilter[] filters;
private final AtomicInteger index;
private final ESLogger logger;
private ResponseFilterChain(ActionFilter[] filters, ESLogger logger) {
this.filters = filters;
this.index = new AtomicInteger(filters.length);
this.logger = logger;
}
@Override
public void proceed(String action, ActionRequest request, ActionListener listener) {
assert false : "response filter chain should never be called on the request side";
}
@Override @SuppressWarnings("unchecked")
public void proceed(String action, ActionResponse response, ActionListener listener) {
int i = index.decrementAndGet();
try {
if (i >= 0) {
filters[i].apply(action, response, listener, this);
} else if (i == -1) {
listener.onResponse(response);
} else {
listener.onFailure(new IllegalStateException("proceed was called too many times"));
}
} catch (Throwable t) {
logger.trace("Error during transport action execution.", t);
listener.onFailure(t);
}
}
}
private static class FilteredActionListener<Response extends ActionResponse> implements ActionListener<Response> {
private final String actionName;
private final ActionListener listener;
private final ResponseFilterChain chain;
private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) {
this.actionName = actionName;
this.listener = listener;
this.chain = chain;
}
@Override
public void onResponse(Response response) {
chain.proceed(actionName, response, listener);
}
@Override
public void onFailure(Throwable e) {
listener.onFailure(e);
}
}
}

View File

@ -27,6 +27,7 @@ import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.settings.ImmutableSettings;
import org.elasticsearch.test.ElasticsearchTestCase;
import org.junit.Before;
import org.junit.Test;
import java.util.*;
@ -40,8 +41,15 @@ import static org.hamcrest.CoreMatchers.*;
public class TransportActionFilterChainTests extends ElasticsearchTestCase {
private AtomicInteger counter;
@Before
public void init() throws Exception {
counter = new AtomicInteger();
}
@Test
public void testActionFilters() throws ExecutionException, InterruptedException {
public void testActionFiltersRequest() throws ExecutionException, InterruptedException {
int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters);
@ -51,7 +59,7 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
Set<ActionFilter> filters = new HashSet<>();
for (Integer order : orders) {
filters.add(new TestFilter(order, randomFrom(Operation.values())));
filters.add(new RequestTestFilter(order, randomFrom(RequestOperation.values())));
}
String actionName = randomAsciiOfLength(randomInt(30));
@ -74,12 +82,12 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
List<ActionFilter> expectedActionFilters = Lists.newArrayList();
boolean errorExpected = false;
for (ActionFilter filter : actionFiltersByOrder) {
TestFilter testFilter = (TestFilter) filter;
RequestTestFilter testFilter = (RequestTestFilter) filter;
expectedActionFilters.add(testFilter);
if (testFilter.callback == Operation.LISTENER_FAILURE) {
if (testFilter.callback == RequestOperation.LISTENER_FAILURE) {
errorExpected = true;
}
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) {
if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) {
break;
}
}
@ -93,29 +101,29 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true));
}
List<TestFilter> testFiltersByLastExecution = Lists.newArrayList();
List<RequestTestFilter> testFiltersByLastExecution = Lists.newArrayList();
for (ActionFilter actionFilter : actionFilters.filters()) {
testFiltersByLastExecution.add((TestFilter) actionFilter);
testFiltersByLastExecution.add((RequestTestFilter) actionFilter);
}
Collections.sort(testFiltersByLastExecution, new Comparator<TestFilter>() {
Collections.sort(testFiltersByLastExecution, new Comparator<RequestTestFilter>() {
@Override
public int compare(TestFilter o1, TestFilter o2) {
public int compare(RequestTestFilter o1, RequestTestFilter o2) {
return Integer.compare(o1.executionToken, o2.executionToken);
}
});
ArrayList<TestFilter> finalTestFilters = Lists.newArrayList();
ArrayList<RequestTestFilter> finalTestFilters = Lists.newArrayList();
for (ActionFilter filter : testFiltersByLastExecution) {
TestFilter testFilter = (TestFilter) filter;
RequestTestFilter testFilter = (RequestTestFilter) filter;
finalTestFilters.add(testFilter);
if (!(testFilter.callback == Operation.CONTINUE_PROCESSING) ) {
if (!(testFilter.callback == RequestOperation.CONTINUE_PROCESSING) ) {
break;
}
}
assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
for (int i = 0; i < finalTestFilters.size(); i++) {
TestFilter testFilter = finalTestFilters.get(i);
RequestTestFilter testFilter = finalTestFilters.get(i);
assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));
@ -123,15 +131,97 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
}
@Test
public void testTooManyContinueProcessing() throws ExecutionException, InterruptedException {
public void testActionFiltersResponse() throws ExecutionException, InterruptedException {
int numFilters = randomInt(10);
Set<Integer> orders = new HashSet<>(numFilters);
while (orders.size() < numFilters) {
orders.add(randomInt(10));
}
Set<ActionFilter> filters = new HashSet<>();
for (Integer order : orders) {
filters.add(new ResponseTestFilter(order, randomFrom(ResponseOperation.values())));
}
String actionName = randomAsciiOfLength(randomInt(30));
ActionFilters actionFilters = new ActionFilters(filters);
TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(ImmutableSettings.EMPTY, actionName, null, actionFilters) {
@Override
protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) {
listener.onResponse(new TestResponse());
}
};
ArrayList<ActionFilter> actionFiltersByOrder = Lists.newArrayList(filters);
Collections.sort(actionFiltersByOrder, new Comparator<ActionFilter>() {
@Override
public int compare(ActionFilter o1, ActionFilter o2) {
return Integer.compare(o2.order(), o1.order());
}
});
List<ActionFilter> expectedActionFilters = Lists.newArrayList();
boolean errorExpected = false;
for (ActionFilter filter : actionFiltersByOrder) {
ResponseTestFilter testFilter = (ResponseTestFilter) filter;
expectedActionFilters.add(testFilter);
if (testFilter.callback == ResponseOperation.LISTENER_FAILURE) {
errorExpected = true;
}
if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
break;
}
}
PlainListenableActionFuture<TestResponse> future = new PlainListenableActionFuture<>(false, null);
transportAction.execute(new TestRequest(), future);
try {
assertThat(future.get(), notNullValue());
assertThat("shouldn't get here if an error is expected", errorExpected, equalTo(false));
} catch(Throwable t) {
assertThat("shouldn't get here if an error is not expected " + t.getMessage(), errorExpected, equalTo(true));
}
List<ResponseTestFilter> testFiltersByLastExecution = Lists.newArrayList();
for (ActionFilter actionFilter : actionFilters.filters()) {
testFiltersByLastExecution.add((ResponseTestFilter) actionFilter);
}
Collections.sort(testFiltersByLastExecution, new Comparator<ResponseTestFilter>() {
@Override
public int compare(ResponseTestFilter o1, ResponseTestFilter o2) {
return Integer.compare(o1.executionToken, o2.executionToken);
}
});
ArrayList<ResponseTestFilter> finalTestFilters = Lists.newArrayList();
for (ActionFilter filter : testFiltersByLastExecution) {
ResponseTestFilter testFilter = (ResponseTestFilter) filter;
finalTestFilters.add(testFilter);
if (testFilter.callback != ResponseOperation.CONTINUE_PROCESSING) {
break;
}
}
assertThat(finalTestFilters.size(), equalTo(expectedActionFilters.size()));
for (int i = 0; i < finalTestFilters.size(); i++) {
ResponseTestFilter testFilter = finalTestFilters.get(i);
assertThat(testFilter, equalTo(expectedActionFilters.get(i)));
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));
}
}
@Test
public void testTooManyContinueProcessingRequest() throws ExecutionException, InterruptedException {
final int additionalContinueCount = randomInt(10);
TestFilter testFilter = new TestFilter(randomInt(), new Callback() {
RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
@Override
public void execute(final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) {
for (int i = 0; i <= additionalContinueCount; i++) {
actionFilterChain.continueProcessing(action, actionRequest, actionListener);
actionFilterChain.proceed(action, actionRequest, actionListener);
}
}
});
@ -180,24 +270,84 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
}
}
private final AtomicInteger counter = new AtomicInteger();
@Test
public void testTooManyContinueProcessingResponse() throws ExecutionException, InterruptedException {
private class TestFilter implements ActionFilter {
final int additionalContinueCount = randomInt(10);
ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
for (int i = 0; i <= additionalContinueCount; i++) {
chain.proceed(action, response, listener);
}
}
});
Set<ActionFilter> filters = new HashSet<>();
filters.add(testFilter);
String actionName = randomAsciiOfLength(randomInt(30));
ActionFilters actionFilters = new ActionFilters(filters);
TransportAction<TestRequest, TestResponse> transportAction = new TransportAction<TestRequest, TestResponse>(ImmutableSettings.EMPTY, actionName, null, actionFilters) {
@Override
protected void doExecute(TestRequest request, ActionListener<TestResponse> listener) {
listener.onResponse(new TestResponse());
}
};
final CountDownLatch latch = new CountDownLatch(additionalContinueCount + 1);
final AtomicInteger responses = new AtomicInteger();
final List<Throwable> failures = new CopyOnWriteArrayList<>();
transportAction.execute(new TestRequest(), new ActionListener<TestResponse>() {
@Override
public void onResponse(TestResponse testResponse) {
responses.incrementAndGet();
latch.countDown();
}
@Override
public void onFailure(Throwable e) {
failures.add(e);
latch.countDown();
}
});
if (!latch.await(10, TimeUnit.SECONDS)) {
fail("timeout waiting for the filter to notify the listener as many times as expected");
}
assertThat(testFilter.runs.get(), equalTo(1));
assertThat(testFilter.lastActionName, equalTo(actionName));
assertThat(responses.get(), equalTo(1));
assertThat(failures.size(), equalTo(additionalContinueCount));
for (Throwable failure : failures) {
assertThat(failure, instanceOf(IllegalStateException.class));
}
}
private class RequestTestFilter implements ActionFilter {
private final RequestCallback callback;
private final int order;
private final Callback callback;
AtomicInteger runs = new AtomicInteger();
volatile String lastActionName;
volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list
TestFilter(int order, Callback callback) {
RequestTestFilter(int order, RequestCallback callback) {
this.order = order;
this.callback = callback;
}
@Override
public int order() {
return order;
}
@SuppressWarnings("unchecked")
@Override
public void process(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
public void apply(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
this.runs.incrementAndGet();
this.lastActionName = action;
this.executionToken = counter.incrementAndGet();
@ -205,16 +355,47 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
}
@Override
public int order() {
return order;
public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
chain.proceed(action, response, listener);
}
}
private static enum Operation implements Callback {
private class ResponseTestFilter implements ActionFilter {
private final ResponseCallback callback;
private final int order;
AtomicInteger runs = new AtomicInteger();
volatile String lastActionName;
volatile int executionToken = Integer.MAX_VALUE; //the filters that don't run will go last in the sorted list
ResponseTestFilter(int order, ResponseCallback callback) {
this.order = order;
this.callback = callback;
}
@Override
public int order() {
return order;
}
@Override
public void apply(String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
chain.proceed(action, request, listener);
}
@Override
public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
this.runs.incrementAndGet();
this.lastActionName = action;
this.executionToken = counter.incrementAndGet();
this.callback.execute(action, response, listener, chain);
}
}
private static enum RequestOperation implements RequestCallback {
CONTINUE_PROCESSING {
@Override
public void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
actionFilterChain.continueProcessing(action, actionRequest, actionListener);
actionFilterChain.proceed(action, actionRequest, actionListener);
}
},
LISTENER_RESPONSE {
@ -232,10 +413,36 @@ public class TransportActionFilterChainTests extends ElasticsearchTestCase {
}
}
private static interface Callback {
private static enum ResponseOperation implements ResponseCallback {
CONTINUE_PROCESSING {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
chain.proceed(action, response, listener);
}
},
LISTENER_RESPONSE {
@Override
@SuppressWarnings("unchecked")
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
listener.onResponse(new TestResponse());
}
},
LISTENER_FAILURE {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
listener.onFailure(new ElasticsearchTimeoutException(""));
}
}
}
private static interface RequestCallback {
void execute(String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain);
}
private static interface ResponseCallback {
void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
}
private static class TestRequest extends ActionRequest {
@Override
public ActionRequestValidationException validate() {