Cleanup the raw generics in ActionFilter

Mostly these were pretty easy to clean up by insisting that the request
and response stays consistent across the filter. There are a few places
where we have to make assumptions in tests but those are valid assumptions
for the test.
This commit is contained in:
Nik Everett 2016-01-19 21:53:47 -05:00
parent 3178d24bea
commit 9e3e024358
5 changed files with 66 additions and 45 deletions

View File

@ -40,13 +40,15 @@ public interface ActionFilter {
* Enables filtering the execution of an action on the request side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(Task task, String action, ActionRequest<?> request, ActionListener<?> listener, ActionFilterChain chain);
<Request extends ActionRequest<Request>, Response extends ActionResponse> void apply(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> chain);
/**
* Enables filtering the execution of an action on the response side, either by sending a response through the
* {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain}
*/
void apply(String action, ActionResponse response, ActionListener<?> listener, ActionFilterChain chain);
<Response extends ActionResponse> void apply(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain);
/**
* A simple base class for injectable action filters that spares the implementation from handling the
@ -60,7 +62,8 @@ public interface ActionFilter {
}
@Override
public final void apply(Task task, String action, ActionRequest<?> request, ActionListener<?> listener, ActionFilterChain chain) {
public final <Request extends ActionRequest<Request>, Response extends ActionResponse> void apply(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> chain) {
if (apply(action, request, listener)) {
chain.proceed(task, action, request, listener);
}
@ -73,7 +76,8 @@ public interface ActionFilter {
protected abstract boolean apply(String action, ActionRequest<?> request, ActionListener<?> listener);
@Override
public final void apply(String action, ActionResponse response, ActionListener<?> listener, ActionFilterChain chain) {
public final <Response extends ActionResponse> void apply(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
if (apply(action, response, listener)) {
chain.proceed(action, response, listener);
}

View File

@ -27,17 +27,17 @@ import org.elasticsearch.tasks.Task;
/**
* A filter chain allowing to continue and process the transport action request
*/
public interface ActionFilterChain {
public interface ActionFilterChain<Request extends ActionRequest<Request>, Response extends ActionResponse> {
/**
* Continue processing the request. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/
void proceed(Task task, final String action, final ActionRequest request, final ActionListener listener);
void proceed(Task task, final String action, final Request request, final ActionListener<Response> listener);
/**
* Continue processing the response. Should only be called if a response has not been sent through
* the given {@link ActionListener listener}
*/
void proceed(final String action, final ActionResponse response, final ActionListener listener);
void proceed(final String action, final Response response, final ActionListener<Response> listener);
}

View File

@ -104,7 +104,7 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
listener.onFailure(t);
}
} else {
RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger);
RequestFilterChain<Request, Response> requestFilterChain = new RequestFilterChain<>(this, logger);
requestFilterChain.proceed(task, actionName, request, listener);
}
}
@ -115,7 +115,8 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
protected abstract void doExecute(Request request, ActionListener<Response> listener);
private static class RequestFilterChain<Request extends ActionRequest<Request>, Response extends ActionResponse> implements ActionFilterChain {
private static class RequestFilterChain<Request extends ActionRequest<Request>, Response extends ActionResponse>
implements ActionFilterChain<Request, Response> {
private final TransportAction<Request, Response> action;
private final AtomicInteger index = new AtomicInteger();
@ -126,14 +127,15 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
this.logger = logger;
}
@Override @SuppressWarnings("unchecked")
public void proceed(Task task, String actionName, ActionRequest request, ActionListener listener) {
@Override
public void proceed(Task task, String actionName, Request request, ActionListener<Response> listener) {
int i = index.getAndIncrement();
try {
if (i < this.action.filters.length) {
this.action.filters[i].apply(task, actionName, request, listener, this);
} else if (i == this.action.filters.length) {
this.action.doExecute(task, (Request) request, new FilteredActionListener<Response>(actionName, listener, new ResponseFilterChain(this.action.filters, logger)));
this.action.doExecute(task, request, new FilteredActionListener<Response>(actionName, listener,
new ResponseFilterChain<>(this.action.filters, logger)));
} else {
listener.onFailure(new IllegalStateException("proceed was called too many times"));
}
@ -144,12 +146,13 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
}
@Override
public void proceed(String action, ActionResponse response, ActionListener listener) {
public void proceed(String action, Response response, ActionListener<Response> listener) {
assert false : "request filter chain should never be called on the response side";
}
}
private static class ResponseFilterChain implements ActionFilterChain {
private static class ResponseFilterChain<Request extends ActionRequest<Request>, Response extends ActionResponse>
implements ActionFilterChain<Request, Response> {
private final ActionFilter[] filters;
private final AtomicInteger index;
@ -162,12 +165,12 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
}
@Override
public void proceed(Task task, String action, ActionRequest request, ActionListener listener) {
public void proceed(Task task, String action, Request request, ActionListener<Response> listener) {
assert false : "response filter chain should never be called on the request side";
}
@Override @SuppressWarnings("unchecked")
public void proceed(String action, ActionResponse response, ActionListener listener) {
@Override
public void proceed(String action, Response response, ActionListener<Response> listener) {
int i = index.decrementAndGet();
try {
if (i >= 0) {
@ -187,10 +190,10 @@ public abstract class TransportAction<Request extends ActionRequest<Request>, Re
private static class FilteredActionListener<Response extends ActionResponse> implements ActionListener<Response> {
private final String actionName;
private final ActionListener listener;
private final ResponseFilterChain chain;
private final ActionListener<Response> listener;
private final ResponseFilterChain<?, Response> chain;
private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) {
private FilteredActionListener(String actionName, ActionListener<Response> listener, ResponseFilterChain<?, Response> chain) {
this.actionName = actionName;
this.listener = listener;
this.chain = chain;

View File

@ -220,9 +220,10 @@ public class TransportActionFilterChainTests extends ESTestCase {
RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() {
@Override
public void execute(Task task, final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) {
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void execute(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) {
for (int i = 0; i <= additionalContinueCount; i++) {
actionFilterChain.proceed(task, action, actionRequest, actionListener);
actionFilterChain.proceed(task, action, request, listener);
}
}
});
@ -276,7 +277,8 @@ public class TransportActionFilterChainTests extends ESTestCase {
ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
public <Response extends ActionResponse> void execute(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
for (int i = 0; i <= additionalContinueCount; i++) {
chain.proceed(action, response, listener);
}
@ -344,17 +346,18 @@ public class TransportActionFilterChainTests extends ESTestCase {
return order;
}
@SuppressWarnings("unchecked")
@Override
public void apply(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void apply(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> chain) {
this.runs.incrementAndGet();
this.lastActionName = action;
this.executionToken = counter.incrementAndGet();
this.callback.execute(task, action, actionRequest, actionListener, actionFilterChain);
this.callback.execute(task, action, request, listener, chain);
}
@Override
public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
public <Response extends ActionResponse> void apply(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
chain.proceed(action, response, listener);
}
}
@ -377,12 +380,14 @@ public class TransportActionFilterChainTests extends ESTestCase {
}
@Override
public void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) {
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void apply(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> chain) {
chain.proceed(task, action, request, listener);
}
@Override
public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
public <Response extends ActionResponse> void apply(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
this.runs.incrementAndGet();
this.lastActionName = action;
this.executionToken = counter.incrementAndGet();
@ -393,21 +398,24 @@ public class TransportActionFilterChainTests extends ESTestCase {
private static enum RequestOperation implements RequestCallback {
CONTINUE_PROCESSING {
@Override
public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
actionFilterChain.proceed(task, action, actionRequest, actionListener);
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void execute(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) {
actionFilterChain.proceed(task, action, request, listener);
}
},
LISTENER_RESPONSE {
@Override
@SuppressWarnings("unchecked")
public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
actionListener.onResponse(new TestResponse());
@SuppressWarnings("unchecked") // Safe because its all we test with
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void execute(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) {
((ActionListener<TestResponse>) listener).onResponse(new TestResponse());
}
},
LISTENER_FAILURE {
@Override
public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) {
actionListener.onFailure(new ElasticsearchTimeoutException(""));
public <Request extends ActionRequest<Request>, Response extends ActionResponse> void execute(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain) {
listener.onFailure(new ElasticsearchTimeoutException(""));
}
}
}
@ -415,31 +423,36 @@ public class TransportActionFilterChainTests extends ESTestCase {
private static enum ResponseOperation implements ResponseCallback {
CONTINUE_PROCESSING {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
public <Response extends ActionResponse> void execute(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
chain.proceed(action, response, listener);
}
},
LISTENER_RESPONSE {
@Override
@SuppressWarnings("unchecked")
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
listener.onResponse(new TestResponse());
@SuppressWarnings("unchecked") // Safe because its all we test with
public <Response extends ActionResponse> void execute(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
((ActionListener<TestResponse>) listener).onResponse(new TestResponse());
}
},
LISTENER_FAILURE {
@Override
public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) {
public <Response extends ActionResponse> void execute(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain) {
listener.onFailure(new ElasticsearchTimeoutException(""));
}
}
}
private static interface RequestCallback {
void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain);
<Request extends ActionRequest<Request>, Response extends ActionResponse> void execute(Task task, String action, Request request,
ActionListener<Response> listener, ActionFilterChain<Request, Response> actionFilterChain);
}
private static interface ResponseCallback {
void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain);
<Response extends ActionResponse> void execute(String action, Response response, ActionListener<Response> listener,
ActionFilterChain<?, Response> chain);
}
public static class TestRequest extends ActionRequest<TestRequest> {

View File

@ -20,6 +20,7 @@
package org.elasticsearch.cluster;
import com.carrotsearch.hppc.cursors.ObjectCursor;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionModule;
@ -100,7 +101,7 @@ public class ClusterInfoServiceIT extends ESIntegTestCase {
}
@Override
protected boolean apply(String action, ActionRequest request, ActionListener listener) {
protected boolean apply(String action, ActionRequest<?> request, ActionListener<?> listener) {
if (blockedActions.contains(action)) {
throw new ElasticsearchException("force exception on [" + action + "]");
}
@ -108,7 +109,7 @@ public class ClusterInfoServiceIT extends ESIntegTestCase {
}
@Override
protected boolean apply(String action, ActionResponse response, ActionListener listener) {
protected boolean apply(String action, ActionResponse response, ActionListener<?> listener) {
return true;
}