diff --git a/core/src/main/java/org/elasticsearch/action/support/ActionFilter.java b/core/src/main/java/org/elasticsearch/action/support/ActionFilter.java index 6c08eec323f..d753eda4c69 100644 --- a/core/src/main/java/org/elasticsearch/action/support/ActionFilter.java +++ b/core/src/main/java/org/elasticsearch/action/support/ActionFilter.java @@ -40,13 +40,15 @@ public interface ActionFilter { * Enables filtering the execution of an action on the request side, either by sending a response through the * {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain} */ - void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain); + , Response extends ActionResponse> void apply(Task task, String action, Request request, + ActionListener listener, ActionFilterChain chain); /** * Enables filtering the execution of an action on the response side, either by sending a response through the * {@link ActionListener} or by continuing the execution through the given {@link ActionFilterChain chain} */ - void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain); + void apply(String action, Response response, ActionListener listener, + ActionFilterChain chain); /** * A simple base class for injectable action filters that spares the implementation from handling the @@ -60,7 +62,8 @@ public interface ActionFilter { } @Override - public final void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) { + public final , Response extends ActionResponse> void apply(Task task, String action, Request request, + ActionListener listener, ActionFilterChain chain) { if (apply(action, request, listener)) { chain.proceed(task, action, request, listener); } @@ -73,7 +76,8 @@ public interface ActionFilter { protected abstract boolean apply(String action, ActionRequest request, ActionListener listener); @Override - public final void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public final void apply(String action, Response response, ActionListener listener, + ActionFilterChain chain) { if (apply(action, response, listener)) { chain.proceed(action, response, listener); } diff --git a/core/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java b/core/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java index 9b1ae9b2693..54f55e187a9 100644 --- a/core/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java +++ b/core/src/main/java/org/elasticsearch/action/support/ActionFilterChain.java @@ -27,17 +27,17 @@ import org.elasticsearch.tasks.Task; /** * A filter chain allowing to continue and process the transport action request */ -public interface ActionFilterChain { +public interface ActionFilterChain, Response extends ActionResponse> { /** * Continue processing the request. Should only be called if a response has not been sent through * the given {@link ActionListener listener} */ - void proceed(Task task, final String action, final ActionRequest request, final ActionListener listener); + void proceed(Task task, final String action, final Request request, final ActionListener listener); /** * Continue processing the response. Should only be called if a response has not been sent through * the given {@link ActionListener listener} */ - void proceed(final String action, final ActionResponse response, final ActionListener listener); + void proceed(final String action, final Response response, final ActionListener listener); } diff --git a/core/src/main/java/org/elasticsearch/action/support/TransportAction.java b/core/src/main/java/org/elasticsearch/action/support/TransportAction.java index eb62903bf34..584ff14e756 100644 --- a/core/src/main/java/org/elasticsearch/action/support/TransportAction.java +++ b/core/src/main/java/org/elasticsearch/action/support/TransportAction.java @@ -104,7 +104,7 @@ public abstract class TransportAction, Re listener.onFailure(t); } } else { - RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); + RequestFilterChain requestFilterChain = new RequestFilterChain<>(this, logger); requestFilterChain.proceed(task, actionName, request, listener); } } @@ -115,7 +115,8 @@ public abstract class TransportAction, Re protected abstract void doExecute(Request request, ActionListener listener); - private static class RequestFilterChain, Response extends ActionResponse> implements ActionFilterChain { + private static class RequestFilterChain, Response extends ActionResponse> + implements ActionFilterChain { private final TransportAction action; private final AtomicInteger index = new AtomicInteger(); @@ -126,14 +127,15 @@ public abstract class TransportAction, Re this.logger = logger; } - @Override @SuppressWarnings("unchecked") - public void proceed(Task task, String actionName, ActionRequest request, ActionListener listener) { + @Override + public void proceed(Task task, String actionName, Request request, ActionListener listener) { int i = index.getAndIncrement(); try { if (i < this.action.filters.length) { this.action.filters[i].apply(task, actionName, request, listener, this); } else if (i == this.action.filters.length) { - this.action.doExecute(task, (Request) request, new FilteredActionListener(actionName, listener, new ResponseFilterChain(this.action.filters, logger))); + this.action.doExecute(task, request, new FilteredActionListener(actionName, listener, + new ResponseFilterChain<>(this.action.filters, logger))); } else { listener.onFailure(new IllegalStateException("proceed was called too many times")); } @@ -144,12 +146,13 @@ public abstract class TransportAction, Re } @Override - public void proceed(String action, ActionResponse response, ActionListener listener) { + public void proceed(String action, Response response, ActionListener listener) { assert false : "request filter chain should never be called on the response side"; } } - private static class ResponseFilterChain implements ActionFilterChain { + private static class ResponseFilterChain, Response extends ActionResponse> + implements ActionFilterChain { private final ActionFilter[] filters; private final AtomicInteger index; @@ -162,12 +165,12 @@ public abstract class TransportAction, Re } @Override - public void proceed(Task task, String action, ActionRequest request, ActionListener listener) { + public void proceed(Task task, String action, Request request, ActionListener listener) { assert false : "response filter chain should never be called on the request side"; } - @Override @SuppressWarnings("unchecked") - public void proceed(String action, ActionResponse response, ActionListener listener) { + @Override + public void proceed(String action, Response response, ActionListener listener) { int i = index.decrementAndGet(); try { if (i >= 0) { @@ -187,10 +190,10 @@ public abstract class TransportAction, Re private static class FilteredActionListener implements ActionListener { private final String actionName; - private final ActionListener listener; - private final ResponseFilterChain chain; + private final ActionListener listener; + private final ResponseFilterChain chain; - private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) { + private FilteredActionListener(String actionName, ActionListener listener, ResponseFilterChain chain) { this.actionName = actionName; this.listener = listener; this.chain = chain; diff --git a/core/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java b/core/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java index fed4e1d6384..00068c05efe 100644 --- a/core/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/TransportActionFilterChainTests.java @@ -220,9 +220,10 @@ public class TransportActionFilterChainTests extends ESTestCase { RequestTestFilter testFilter = new RequestTestFilter(randomInt(), new RequestCallback() { @Override - public void execute(Task task, final String action, final ActionRequest actionRequest, final ActionListener actionListener, final ActionFilterChain actionFilterChain) { + public , Response extends ActionResponse> void execute(Task task, String action, Request request, + ActionListener listener, ActionFilterChain actionFilterChain) { for (int i = 0; i <= additionalContinueCount; i++) { - actionFilterChain.proceed(task, action, actionRequest, actionListener); + actionFilterChain.proceed(task, action, request, listener); } } }); @@ -276,7 +277,8 @@ public class TransportActionFilterChainTests extends ESTestCase { ResponseTestFilter testFilter = new ResponseTestFilter(randomInt(), new ResponseCallback() { @Override - public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public void execute(String action, Response response, ActionListener listener, + ActionFilterChain chain) { for (int i = 0; i <= additionalContinueCount; i++) { chain.proceed(action, response, listener); } @@ -344,17 +346,18 @@ public class TransportActionFilterChainTests extends ESTestCase { return order; } - @SuppressWarnings("unchecked") @Override - public void apply(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { + public , Response extends ActionResponse> void apply(Task task, String action, Request request, + ActionListener listener, ActionFilterChain chain) { this.runs.incrementAndGet(); this.lastActionName = action; this.executionToken = counter.incrementAndGet(); - this.callback.execute(task, action, actionRequest, actionListener, actionFilterChain); + this.callback.execute(task, action, request, listener, chain); } @Override - public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public void apply(String action, Response response, ActionListener listener, + ActionFilterChain chain) { chain.proceed(action, response, listener); } } @@ -377,12 +380,14 @@ public class TransportActionFilterChainTests extends ESTestCase { } @Override - public void apply(Task task, String action, ActionRequest request, ActionListener listener, ActionFilterChain chain) { + public , Response extends ActionResponse> void apply(Task task, String action, Request request, + ActionListener listener, ActionFilterChain chain) { chain.proceed(task, action, request, listener); } @Override - public void apply(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public void apply(String action, Response response, ActionListener listener, + ActionFilterChain chain) { this.runs.incrementAndGet(); this.lastActionName = action; this.executionToken = counter.incrementAndGet(); @@ -393,21 +398,24 @@ public class TransportActionFilterChainTests extends ESTestCase { private static enum RequestOperation implements RequestCallback { CONTINUE_PROCESSING { @Override - public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { - actionFilterChain.proceed(task, action, actionRequest, actionListener); + public , Response extends ActionResponse> void execute(Task task, String action, Request request, + ActionListener listener, ActionFilterChain actionFilterChain) { + actionFilterChain.proceed(task, action, request, listener); } }, LISTENER_RESPONSE { @Override - @SuppressWarnings("unchecked") - public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { - actionListener.onResponse(new TestResponse()); + @SuppressWarnings("unchecked") // Safe because its all we test with + public , Response extends ActionResponse> void execute(Task task, String action, Request request, + ActionListener listener, ActionFilterChain actionFilterChain) { + ((ActionListener) listener).onResponse(new TestResponse()); } }, LISTENER_FAILURE { @Override - public void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain) { - actionListener.onFailure(new ElasticsearchTimeoutException("")); + public , Response extends ActionResponse> void execute(Task task, String action, Request request, + ActionListener listener, ActionFilterChain actionFilterChain) { + listener.onFailure(new ElasticsearchTimeoutException("")); } } } @@ -415,31 +423,36 @@ public class TransportActionFilterChainTests extends ESTestCase { private static enum ResponseOperation implements ResponseCallback { CONTINUE_PROCESSING { @Override - public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public void execute(String action, Response response, ActionListener listener, + ActionFilterChain chain) { chain.proceed(action, response, listener); } }, LISTENER_RESPONSE { @Override - @SuppressWarnings("unchecked") - public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { - listener.onResponse(new TestResponse()); + @SuppressWarnings("unchecked") // Safe because its all we test with + public void execute(String action, Response response, ActionListener listener, + ActionFilterChain chain) { + ((ActionListener) listener).onResponse(new TestResponse()); } }, LISTENER_FAILURE { @Override - public void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain) { + public void execute(String action, Response response, ActionListener listener, + ActionFilterChain chain) { listener.onFailure(new ElasticsearchTimeoutException("")); } } } private static interface RequestCallback { - void execute(Task task, String action, ActionRequest actionRequest, ActionListener actionListener, ActionFilterChain actionFilterChain); + , Response extends ActionResponse> void execute(Task task, String action, Request request, + ActionListener listener, ActionFilterChain actionFilterChain); } private static interface ResponseCallback { - void execute(String action, ActionResponse response, ActionListener listener, ActionFilterChain chain); + void execute(String action, Response response, ActionListener listener, + ActionFilterChain chain); } public static class TestRequest extends ActionRequest { diff --git a/core/src/test/java/org/elasticsearch/cluster/ClusterInfoServiceIT.java b/core/src/test/java/org/elasticsearch/cluster/ClusterInfoServiceIT.java index ac2845c86ab..9a8e8fb7268 100644 --- a/core/src/test/java/org/elasticsearch/cluster/ClusterInfoServiceIT.java +++ b/core/src/test/java/org/elasticsearch/cluster/ClusterInfoServiceIT.java @@ -20,6 +20,7 @@ package org.elasticsearch.cluster; import com.carrotsearch.hppc.cursors.ObjectCursor; + import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionModule; @@ -100,7 +101,7 @@ public class ClusterInfoServiceIT extends ESIntegTestCase { } @Override - protected boolean apply(String action, ActionRequest request, ActionListener listener) { + protected boolean apply(String action, ActionRequest request, ActionListener listener) { if (blockedActions.contains(action)) { throw new ElasticsearchException("force exception on [" + action + "]"); } @@ -108,7 +109,7 @@ public class ClusterInfoServiceIT extends ESIntegTestCase { } @Override - protected boolean apply(String action, ActionResponse response, ActionListener listener) { + protected boolean apply(String action, ActionResponse response, ActionListener listener) { return true; }