diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/Security.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/Security.java index ebd26d98995..bcd2c6ec949 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -704,9 +704,10 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin { return Collections.singletonList(new TransportInterceptor() { @Override public TransportRequestHandler interceptHandler(String action, String executor, + boolean forceExecution, TransportRequestHandler actualHandler) { assert securityInterceptor.get() != null; - return securityInterceptor.get().interceptHandler(action, executor, actualHandler); + return securityInterceptor.get().interceptHandler(action, executor, forceExecution, actualHandler); } @Override diff --git a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java index 2ec8b6015b0..ddbc803aef5 100644 --- a/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java +++ b/elasticsearch/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java @@ -10,6 +10,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.DestructiveOperations; import org.elasticsearch.common.CheckedConsumer; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.tasks.Task; @@ -43,7 +44,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.concurrent.Executor; -import java.util.function.Consumer; import static org.elasticsearch.xpack.XPackSettings.TRANSPORT_SSL_ENABLED; import static org.elasticsearch.xpack.security.Security.setting; @@ -129,8 +129,9 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor @Override public TransportRequestHandler interceptHandler(String action, String executor, + boolean forceExecution, TransportRequestHandler actualHandler) { - return new ProfileSecuredRequestHandler<>(action, executor, actualHandler, profileFilters, + return new ProfileSecuredRequestHandler<>(action, forceExecution, executor, actualHandler, profileFilters, licenseState, threadPool); } @@ -179,10 +180,11 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor private final ThreadContext threadContext; private final String executorName; private final ThreadPool threadPool; + private final boolean forceExecution; - private ProfileSecuredRequestHandler(String action, String executorName, TransportRequestHandler handler, - Map profileFilters, XPackLicenseState licenseState, - ThreadPool threadPool) { + ProfileSecuredRequestHandler(String action, boolean forceExecution, String executorName, TransportRequestHandler handler, + Map profileFilters, XPackLicenseState licenseState, + ThreadPool threadPool) { this.action = action; this.executorName = executorName; this.handler = handler; @@ -190,30 +192,52 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor this.licenseState = licenseState; this.threadContext = threadPool.getThreadContext(); this.threadPool = threadPool; + this.forceExecution = forceExecution; + } + + AbstractRunnable getReceiveRunnable(T request, TransportChannel channel, Task task) { + return new AbstractRunnable() { + @Override + public boolean isForceExecution() { + return forceExecution; + } + + @Override + public void onFailure(Exception e) { + try { + channel.sendResponse(e); + } catch (IOException e1) { + throw new UncheckedIOException(e1); + } + } + + @Override + protected void doRun() throws Exception { + // FIXME we should remove the RequestContext completely since we have ThreadContext but cannot yet due to + // the query cache + RequestContext context = new RequestContext(request, threadContext); + RequestContext.setCurrent(context); + try { + handler.messageReceived(request, channel, task); + } finally { + RequestContext.removeCurrent(); + } + } + }; + } + + @Override + public String toString() { + return "ProfileSecuredRequestHandler{" + + "action='" + action + '\'' + + ", executorName='" + executorName + '\'' + + ", forceExecution=" + forceExecution + + '}'; } @Override public void messageReceived(T request, TransportChannel channel, Task task) throws Exception { - final Consumer onFailure = (e) -> { - try { - channel.sendResponse(e); - } catch (IOException e1) { - throw new UncheckedIOException(e1); - } - }; - final Runnable receiveMessage = () -> { - // FIXME we should remove the RequestContext completely since we have ThreadContext but cannot yet due to - // the query cache - RequestContext context = new RequestContext(request, threadContext); - RequestContext.setCurrent(context); - try { - handler.messageReceived(request, channel, task); - } catch (Exception e) { - onFailure.accept(e); - } finally { - RequestContext.removeCurrent(); - } - }; + final AbstractRunnable receiveMessage = getReceiveRunnable(request, channel, task); try (ThreadContext.StoredContext ctx = threadContext.newStoredContext(true)) { if (licenseState.isAuthAllowed()) { String profile = channel.getProfileName(); @@ -248,17 +272,15 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor try { executor.execute(receiveMessage); } catch (Exception e) { - onFailure.accept(e); + receiveMessage.onFailure(e); } }; - ActionListener filterListener = ActionListener.wrap(consumer, onFailure); + ActionListener filterListener = ActionListener.wrap(consumer, receiveMessage::onFailure); filter.inbound(action, request, channel, filterListener); } else { receiveMessage.run(); } - } catch (Exception e) { - channel.sendResponse(e); } } diff --git a/elasticsearch/src/test/java/org/elasticsearch/transport/SecurityServerTransportServiceTests.java b/elasticsearch/src/test/java/org/elasticsearch/transport/SecurityServerTransportServiceTests.java index 27616a91c03..676ae324552 100644 --- a/elasticsearch/src/test/java/org/elasticsearch/transport/SecurityServerTransportServiceTests.java +++ b/elasticsearch/src/test/java/org/elasticsearch/transport/SecurityServerTransportServiceTests.java @@ -38,12 +38,13 @@ public class SecurityServerTransportServiceTests extends SecurityIntegTestCase { public void testSecurityServerTransportServiceWrapsAllHandlers() { for (TransportService transportService : internalCluster().getInstances(TransportService.class)) { for (Map.Entry entry : transportService.requestHandlers.entrySet()) { - assertThat( + RequestHandlerRegistry handler = entry.getValue(); + assertEquals( "handler not wrapped by " + SecurityServerTransportInterceptor.ProfileSecuredRequestHandler.class + "; do all the handler registration methods have overrides?", - entry.getValue().toString(), - startsWith(SecurityServerTransportInterceptor.ProfileSecuredRequestHandler.class.getName() + "@") - ); + handler.toString(), + "ProfileSecuredRequestHandler{action='" + handler.getAction() + "', executorName='" + handler.getExecutor() + + "', forceExecution=" + handler.isForceExecution() + "}"); } } }