security: restore the correct user when switching to the system user

* security: restore the correct user when switching to the system user

For internal actions where we need to switch to the SystemUser, we should always restore the proper
context after execution. We were restoring an empty context for actions executed by the SystemUser
in the SecurityServerTransportInterceptor.

In order to accomplish this, a few changes have been made. Both the SecurityServerTransportInterceptor
and the SecurityActionFilter delegate to `SecurityContext#executeAsUser` when a user switch is necessary.
Tests were added for this method to ensure that the consumer is executed as the correct user and the proper
user is restored.

While working on this, a few other cleanups were made:

* SecurityContext can never have a null CryptoService, so a null check was removed
* We no longer replace the user with the system user when the system user is already associated with the request
* The security transport interceptor checks the license state and if auth is not allowed, delegate and return
* The security transport interceptor sendWithUser method now requires authentication to be present or a hard
exception is thrown.
* The TransportFilters integration test has been deleted. This was integration test that relied on the ability to
get instances from a node and trace the execution. This has been replaced by additional unit tests in
ServerTransportFilterTests

Closes elastic/elasticsearch#3845

Original commit: elastic/x-pack-elasticsearch@d8bcb59cb7
This commit is contained in:
Jay Modi 2016-10-25 13:48:28 -04:00 committed by GitHub
parent a50bc7946b
commit 7d60f6b365
13 changed files with 494 additions and 436 deletions

View File

@ -15,11 +15,7 @@ import org.elasticsearch.common.logging.LoggerMessageFormat;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.license.License.OperationMode; import org.elasticsearch.license.License.OperationMode;
import org.elasticsearch.xpack.XPackPlugin; import org.elasticsearch.xpack.XPackPlugin;
import org.elasticsearch.xpack.graph.Graph;
import org.elasticsearch.xpack.monitoring.Monitoring;
import org.elasticsearch.xpack.monitoring.MonitoringSettings; import org.elasticsearch.xpack.monitoring.MonitoringSettings;
import org.elasticsearch.xpack.security.Security;
import org.elasticsearch.xpack.watcher.Watcher;
/** /**
* A holder for the current state of the license for all xpack features. * A holder for the current state of the license for all xpack features.

View File

@ -334,7 +334,7 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
ipFilter.set(new IPFilter(settings, auditTrailService, clusterService.getClusterSettings(), licenseState)); ipFilter.set(new IPFilter(settings, auditTrailService, clusterService.getClusterSettings(), licenseState));
components.add(ipFilter.get()); components.add(ipFilter.get());
securityIntercepter.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService, authzService, licenseState, securityIntercepter.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService, authzService, licenseState,
sslService)); sslService, securityContext));
return components; return components;
} }

View File

@ -9,6 +9,8 @@ import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext;
import org.elasticsearch.node.Node;
import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.AuthenticationService;
@ -16,6 +18,8 @@ import org.elasticsearch.xpack.security.crypto.CryptoService;
import org.elasticsearch.xpack.security.user.User; import org.elasticsearch.xpack.security.user.User;
import java.io.IOException; import java.io.IOException;
import java.util.Objects;
import java.util.function.Consumer;
/** /**
* A lightweight utility that can find the current user and authentication information for the local thread. * A lightweight utility that can find the current user and authentication information for the local thread.
@ -26,6 +30,7 @@ public class SecurityContext {
private final ThreadContext threadContext; private final ThreadContext threadContext;
private final CryptoService cryptoService; private final CryptoService cryptoService;
private final boolean signUserHeader; private final boolean signUserHeader;
private final String nodeName;
/** /**
* Creates a new security context. * Creates a new security context.
@ -37,6 +42,7 @@ public class SecurityContext {
this.threadContext = threadPool.getThreadContext(); this.threadContext = threadPool.getThreadContext();
this.cryptoService = cryptoService; this.cryptoService = cryptoService;
this.signUserHeader = AuthenticationService.SIGN_USER_HEADER.get(settings); this.signUserHeader = AuthenticationService.SIGN_USER_HEADER.get(settings);
this.nodeName = Node.NODE_NAME_SETTING.get(settings);
} }
/** Returns the current user information, or null if the current request has no authentication info. */ /** Returns the current user information, or null if the current request has no authentication info. */
@ -47,9 +53,6 @@ public class SecurityContext {
/** Returns the authentication information, or null if the current request has no authentication info. */ /** Returns the authentication information, or null if the current request has no authentication info. */
public Authentication getAuthentication() { public Authentication getAuthentication() {
if (cryptoService == null) {
return null;
}
try { try {
return Authentication.readFromContext(threadContext, cryptoService, signUserHeader); return Authentication.readFromContext(threadContext, cryptoService, signUserHeader);
} catch (IOException e) { } catch (IOException e) {
@ -59,4 +62,38 @@ public class SecurityContext {
return null; return null;
} }
} }
/**
* Sets the user forcefully to the provided user. There must not be an existing user in the ThreadContext otherwise an exception
* will be thrown. This method is package private for testing.
*/
void setUser(User user) {
Objects.requireNonNull(user);
final Authentication.RealmRef lookedUpBy;
if (user.runAs() == null) {
lookedUpBy = null;
} else {
lookedUpBy = new Authentication.RealmRef("__attach", "__attach", nodeName);
}
try {
Authentication authentication =
new Authentication(user, new Authentication.RealmRef("__attach", "__attach", nodeName), lookedUpBy);
authentication.writeToContext(threadContext, cryptoService, signUserHeader);
} catch (IOException e) {
throw new AssertionError("how can we have a IOException with a user we set", e);
}
}
/**
* Runs the consumer in a new context as the provided user. The original constext is provided to the consumer. When this method
* returns, the original context is restored.
*/
public void executeAsUser(User user, Consumer<StoredContext> consumer) {
final StoredContext original = threadContext.newStoredContext();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
setUser(user);
consumer.accept(original);
}
}
} }

View File

@ -93,45 +93,51 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
throw LicenseUtils.newComplianceException(XPackPlugin.SECURITY); throw LicenseUtils.newComplianceException(XPackPlugin.SECURITY);
} }
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user if (licenseState.isAuthAllowed() == false) {
// and then a cleanup action is executed (like for search without a scroll) if (SECURITY_ACTION_MATCHER.test(action)) {
final ThreadContext.StoredContext original = threadContext.newStoredContext(); // TODO we should be nice and just call the listener
final boolean restoreOriginalContext = securityContext.getAuthentication() != null; listener.onFailure(LicenseUtils.newComplianceException(XPackPlugin.SECURITY));
try {
if (licenseState.isAuthAllowed()) {
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
// we should always restore the original here because we forcefully changed to the system user
final ThreadContext.StoredContext toRestore = restoreOriginalContext || useSystemUser ? original : () -> {};
final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(toRestore,
ActionListener.wrap(r -> {
try {
listener.onResponse(sign(r));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}, listener::onFailure));
ActionListener<Void> authenticatedListener = new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
chain.proceed(task, action, request, signingListener);
}
@Override
public void onFailure(Exception e) {
signingListener.onFailure(e);
}
};
if (useSystemUser) {
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
applyInternal(action, request, authenticatedListener);
}
} else {
applyInternal(action, request, authenticatedListener);
}
} else if (SECURITY_ACTION_MATCHER.test(action)) {
throw LicenseUtils.newComplianceException(XPackPlugin.SECURITY);
} else { } else {
chain.proceed(task, action, request, listener); chain.proceed(task, action, request, listener);
} }
return;
}
// only restore the context if it is not empty. This is needed because sometimes a response is sent to the user
// and then a cleanup action is executed (like for search without a scroll)
final boolean restoreOriginalContext = securityContext.getAuthentication() != null;
final boolean useSystemUser = AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action);
// we should always restore the original here because we forcefully changed to the system user
final ThreadContext.StoredContext toRestore = restoreOriginalContext || useSystemUser ? threadContext.newStoredContext() : () -> {};
final ActionListener<ActionResponse> signingListener = new ContextPreservingActionListener<>(toRestore, ActionListener.wrap(r -> {
try {
listener.onResponse(sign(r));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}, listener::onFailure));
ActionListener<Void> authenticatedListener = new ActionListener<Void>() {
@Override
public void onResponse(Void aVoid) {
chain.proceed(task, action, request, signingListener);
}
@Override
public void onFailure(Exception e) {
signingListener.onFailure(e);
}
};
try {
if (useSystemUser) {
securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> {
try {
applyInternal(action, request, authenticatedListener);
} catch (IOException e) {
listener.onFailure(e);
}
});
} else {
applyInternal(action, request, authenticatedListener);
}
} catch (Exception e) { } catch (Exception e) {
listener.onFailure(e); listener.onFailure(e);
} }
@ -147,8 +153,7 @@ public class SecurityActionFilter extends AbstractComponent implements ActionFil
return Integer.MIN_VALUE; return Integer.MIN_VALUE;
} }
private void applyInternal(String action, final ActionRequest request, ActionListener listener) private void applyInternal(String action, final ActionRequest request, ActionListener listener) throws IOException {
throws IOException {
/** /**
here we fallback on the system user. Internal system requests are requests that are triggered by here we fallback on the system user. Internal system requests are requests that are triggered by
the system itself (e.g. pings, update mappings, share relocation, etc...) and were not originated the system itself (e.g. pings, update mappings, share relocation, etc...) and were not originated

View File

@ -67,7 +67,7 @@ public class AuthorizationService extends AbstractComponent {
public static final Setting<Boolean> ANONYMOUS_AUTHORIZATION_EXCEPTION_SETTING = public static final Setting<Boolean> ANONYMOUS_AUTHORIZATION_EXCEPTION_SETTING =
Setting.boolSetting(setting("authc.anonymous.authz_exception"), true, Property.NodeScope); Setting.boolSetting(setting("authc.anonymous.authz_exception"), true, Property.NodeScope);
public static final String INDICES_PERMISSIONS_KEY = "_indices_permissions"; public static final String INDICES_PERMISSIONS_KEY = "_indices_permissions";
static final String ORIGINATING_ACTION_KEY = "_originating_action_name"; public static final String ORIGINATING_ACTION_KEY = "_originating_action_name";
private static final Predicate<String> MONITOR_INDEX_PREDICATE = IndexPrivilege.MONITOR.predicate(); private static final Predicate<String> MONITOR_INDEX_PREDICATE = IndexPrivilege.MONITOR.predicate();

View File

@ -10,9 +10,9 @@ import org.elasticsearch.common.util.concurrent.CountDown;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authz.permission.Role; import org.elasticsearch.xpack.security.authz.permission.Role;
import org.elasticsearch.xpack.security.user.SystemUser;
import org.elasticsearch.xpack.security.support.AutomatonPredicate; import org.elasticsearch.xpack.security.support.AutomatonPredicate;
import org.elasticsearch.xpack.security.support.Automatons; import org.elasticsearch.xpack.security.support.Automatons;
import org.elasticsearch.xpack.security.user.SystemUser;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
@ -29,11 +29,10 @@ public final class AuthorizationUtils {
* This method is used to determine if a request should be executed as the system user, even if the request already * This method is used to determine if a request should be executed as the system user, even if the request already
* has a user associated with it. * has a user associated with it.
* *
* In order for the system user to be used, one of the following conditions must be true: * In order for the user to be replaced by the system user one of the following conditions must be true:
* *
* <ul> * <ul>
* <li>the action is an internal action and no user is associated with the request</li> * <li>the action is an internal action and no user is associated with the request</li>
* <li>the action is an internal action and the system user is already associated with the request</li>
* <li>the action is an internal action and the thread context contains a non-internal action as the originating action</li> * <li>the action is an internal action and the thread context contains a non-internal action as the originating action</li>
* </ul> * </ul>
* *
@ -47,7 +46,7 @@ public final class AuthorizationUtils {
} }
Authentication authentication = threadContext.getTransient(Authentication.AUTHENTICATION_KEY); Authentication authentication = threadContext.getTransient(Authentication.AUTHENTICATION_KEY);
if (authentication == null || SystemUser.is(authentication.getUser())) { if (authentication == null) {
return true; return true;
} }
@ -62,7 +61,7 @@ public final class AuthorizationUtils {
return false; return false;
} }
public static boolean isInternalAction(String action) { private static boolean isInternalAction(String action) {
return INTERNAL_PREDICATE.test(action); return INTERNAL_PREDICATE.test(action);
} }

View File

@ -10,6 +10,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.transport.TransportInterceptor; import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.xpack.security.SecurityContext;
import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.elasticsearch.xpack.security.authz.AuthorizationUtils; import org.elasticsearch.xpack.security.authz.AuthorizationUtils;
@ -50,16 +51,18 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
private final AuthorizationService authzService; private final AuthorizationService authzService;
private final SSLService sslService; private final SSLService sslService;
private final Map<String, ServerTransportFilter> profileFilters; private final Map<String, ServerTransportFilter> profileFilters;
final XPackLicenseState licenseState; private final XPackLicenseState licenseState;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final Settings settings; private final Settings settings;
private final SecurityContext securityContext;
public SecurityServerTransportInterceptor(Settings settings, public SecurityServerTransportInterceptor(Settings settings,
ThreadPool threadPool, ThreadPool threadPool,
AuthenticationService authcService, AuthenticationService authcService,
AuthorizationService authzService, AuthorizationService authzService,
XPackLicenseState licenseState, XPackLicenseState licenseState,
SSLService sslService) { SSLService sslService,
SecurityContext securityContext) {
this.settings = settings; this.settings = settings;
this.threadPool = threadPool; this.threadPool = threadPool;
this.authcService = authcService; this.authcService = authcService;
@ -67,6 +70,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
this.licenseState = licenseState; this.licenseState = licenseState;
this.sslService = sslService; this.sslService = sslService;
this.profileFilters = initializeProfileFilters(); this.profileFilters = initializeProfileFilters();
this.securityContext = securityContext;
} }
@Override @Override
@ -75,15 +79,17 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
@Override @Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request, public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler) { TransportRequestOptions options, TransportResponseHandler<T> handler) {
// Sometimes a system action gets executed like a internal create index request or update mappings request if (licenseState.isAuthAllowed()) {
// which means that the user is copied over to system actions so we need to change the user // Sometimes a system action gets executed like a internal create index request or update mappings request
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) { // which means that the user is copied over to system actions so we need to change the user
try (ThreadContext.StoredContext ctx = threadPool.getThreadContext().stashContext()) { if (AuthorizationUtils.shouldReplaceUserWithSystem(threadPool.getThreadContext(), action)) {
final ThreadContext.StoredContext original = threadPool.getThreadContext().newStoredContext(); securityContext.executeAsUser(SystemUser.INSTANCE, (original) -> sendWithUser(node, action, request, options,
sendWithUser(node, action, request, options, new ContextRestoreResponseHandler<>(original, handler), sender); new ContextRestoreResponseHandler<>(original, handler), sender));
} else {
sendWithUser(node, action, request, options, handler, sender);
} }
} else { } else {
sendWithUser(node, action, request, options, handler, sender); sender.sendRequest(node, action, request, options, handler);
} }
} }
}; };
@ -92,11 +98,12 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
private <T extends TransportResponse> void sendWithUser(DiscoveryNode node, String action, TransportRequest request, private <T extends TransportResponse> void sendWithUser(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler, TransportRequestOptions options, TransportResponseHandler<T> handler,
AsyncSender sender) { AsyncSender sender) {
// There cannot be a request outgoing from this node that is not associated with a user.
if (securityContext.getAuthentication() == null) {
throw new IllegalStateException("there should always be a user when sending a message");
}
try { try {
// this will check if there's a user associated with the request. If there isn't,
// the system user will be attached. There cannot be a request outgoing from this
// node that is not associated with a user.
authcService.attachUserIfMissing(SystemUser.INSTANCE);
sender.sendRequest(node, action, request, options, handler); sender.sendRequest(node, action, request, options, handler);
} catch (Exception e) { } catch (Exception e) {
handler.handleException(new TransportException("failed sending request", e)); handler.handleException(new TransportException("failed sending request", e));
@ -248,14 +255,15 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
} }
/** /**
* This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the4 handle methods * This handler wrapper ensures that the response thread executes with the correct thread context. Before any of the handle methods
* are invoked we restore the context. * are invoked we restore the context.
*/ */
private static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> { static final class ContextRestoreResponseHandler<T extends TransportResponse> implements TransportResponseHandler<T> {
private final TransportResponseHandler<T> delegate; private final TransportResponseHandler<T> delegate;
private final ThreadContext.StoredContext threadContext; private final ThreadContext.StoredContext threadContext;
private ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) { // pkg private for testing
ContextRestoreResponseHandler(ThreadContext.StoredContext threadContext, TransportResponseHandler<T> delegate) {
this.delegate = delegate; this.delegate = delegate;
this.threadContext = threadContext; this.threadContext = threadContext;
} }

View File

@ -151,7 +151,7 @@ public interface ServerTransportFilter {
public void inbound(String action, TransportRequest request, TransportChannel transportChannel, ActionListener<Void> listener) public void inbound(String action, TransportRequest request, TransportChannel transportChannel, ActionListener<Void> listener)
throws IOException { throws IOException {
// TODO is ']' sufficient to mark as shard action? // TODO is ']' sufficient to mark as shard action?
boolean isInternalOrShardAction = action.startsWith("internal:") || action.endsWith("]"); final boolean isInternalOrShardAction = action.startsWith("internal:") || action.endsWith("]");
if (isInternalOrShardAction) { if (isInternalOrShardAction) {
throw authenticationError("executing internal/shard actions is considered malicious and forbidden"); throw authenticationError("executing internal/shard actions is considered malicious and forbidden");
} }

View File

@ -0,0 +1,100 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.concurrent.ThreadContext.StoredContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.Authentication.RealmRef;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.crypto.CryptoService;
import org.elasticsearch.xpack.security.user.SystemUser;
import org.elasticsearch.xpack.security.user.User;
import org.junit.Before;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicReference;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class SecurityContextTests extends ESTestCase {
private boolean signHeader;
private Settings settings;
private ThreadContext threadContext;
private CryptoService cryptoService;
private SecurityContext securityContext;
@Before
public void buildSecurityContext() throws IOException {
signHeader = randomBoolean();
settings = Settings.builder()
.put("path.home", createTempDir())
.put(AuthenticationService.SIGN_USER_HEADER.getKey(), signHeader)
.build();
threadContext = new ThreadContext(settings);
cryptoService = new CryptoService(settings, new Environment(settings));
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(threadContext);
securityContext = new SecurityContext(settings, threadPool, cryptoService);
}
public void testGetAuthenticationAndUserInEmptyContext() throws IOException {
assertNull(securityContext.getAuthentication());
assertNull(securityContext.getUser());
}
public void testGetAuthenticationAndUser() throws IOException {
final User user = new User("test");
final Authentication authentication = new Authentication(user, new RealmRef("ldap", "foo", "node1"), null);
authentication.writeToContext(threadContext, cryptoService, signHeader);
assertEquals(authentication, securityContext.getAuthentication());
assertEquals(user, securityContext.getUser());
}
public void testSetUser() {
final User user = new User("test");
assertNull(securityContext.getAuthentication());
assertNull(securityContext.getUser());
securityContext.setUser(user);
assertEquals(user, securityContext.getUser());
IllegalStateException e = expectThrows(IllegalStateException.class,
() -> securityContext.setUser(randomFrom(user, SystemUser.INSTANCE)));
assertEquals("authentication is already present in the context", e.getMessage());
}
public void testExecuteAsUser() throws IOException {
final User original;
if (randomBoolean()) {
original = new User("test");
final Authentication authentication = new Authentication(original, new RealmRef("ldap", "foo", "node1"), null);
authentication.writeToContext(threadContext, cryptoService, signHeader);
} else {
original = null;
}
final User executionUser = new User("executor");
final AtomicReference<StoredContext> contextAtomicReference = new AtomicReference<>();
securityContext.executeAsUser(executionUser, (originalCtx) -> {
assertEquals(executionUser, securityContext.getUser());
contextAtomicReference.set(originalCtx);
});
final User userAfterExecution = securityContext.getUser();
assertEquals(original, userAfterExecution);
StoredContext originalContext = contextAtomicReference.get();
assertNotNull(originalContext);
originalContext.restore();
assertEquals(original, securityContext.getUser());
}
}

View File

@ -32,11 +32,13 @@ public class AuthorizationUtilsTests extends ESTestCase {
assertThat(AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, randomFrom("indices:foo", "cluster:bar")), is(false)); assertThat(AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, randomFrom("indices:foo", "cluster:bar")), is(false));
} }
public void testSystemUserSwitchWithNullorSystemUser() { public void testSystemUserSwitchWithSystemUser() {
if (randomBoolean()) { threadContext.putTransient(Authentication.AUTHENTICATION_KEY,
threadContext.putTransient(Authentication.AUTHENTICATION_KEY, new Authentication(SystemUser.INSTANCE, new RealmRef("test", "test", "foo"), null));
new Authentication(SystemUser.INSTANCE, new RealmRef("test", "test", "foo"), null)); assertThat(AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, "internal:something"), is(false));
} }
public void testSystemUserSwitchWithNullUser() {
assertThat(AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, "internal:something"), is(true)); assertThat(AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, "internal:something"), is(true));
} }

View File

@ -0,0 +1,211 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor.AsyncSender;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestOptions;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponse.Empty;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.xpack.security.SecurityContext;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.Authentication.RealmRef;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.elasticsearch.xpack.security.crypto.CryptoService;
import org.elasticsearch.xpack.security.transport.SecurityServerTransportInterceptor.ContextRestoreResponseHandler;
import org.elasticsearch.xpack.security.user.SystemUser;
import org.elasticsearch.xpack.security.user.User;
import org.elasticsearch.xpack.ssl.SSLService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
public class SecurityServerTransportInterceptorTests extends ESTestCase {
private Settings settings;
private ThreadPool threadPool;
private ThreadContext threadContext;
private XPackLicenseState xPackLicenseState;
private CryptoService cryptoService;
private SecurityContext securityContext;
@Override
public void setUp() throws Exception {
super.setUp();
settings = Settings.builder().put("path.home", createTempDir()).build();
threadPool = mock(ThreadPool.class);
threadContext = new ThreadContext(settings);
when(threadPool.getThreadContext()).thenReturn(threadContext);
cryptoService = new CryptoService(settings, new Environment(settings));
securityContext = spy(new SecurityContext(settings, threadPool, cryptoService));
xPackLicenseState = mock(XPackLicenseState.class);
when(xPackLicenseState.isAuthAllowed()).thenReturn(true);
}
public void testSendAsyncUnlicensed() {
SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor(settings, threadPool,
mock(AuthenticationService.class), mock(AuthorizationService.class), xPackLicenseState, mock(SSLService.class),
securityContext);
when(xPackLicenseState.isAuthAllowed()).thenReturn(false);
AtomicBoolean calledWrappedSender = new AtomicBoolean(false);
AsyncSender sender = interceptor.interceptSender(new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler) {
if (calledWrappedSender.compareAndSet(false, true) == false) {
fail("sender called more than once!");
}
}
});
sender.sendRequest(null, null, null, null, null);
assertTrue(calledWrappedSender.get());
verify(xPackLicenseState).isAuthAllowed();
verifyNoMoreInteractions(xPackLicenseState);
verifyZeroInteractions(securityContext);
}
public void testSendAsync() throws Exception {
final User user = new User("test");
final Authentication authentication = new Authentication(user, new RealmRef("ldap", "foo", "node1"), null);
authentication.writeToContext(threadContext, cryptoService, AuthenticationService.SIGN_USER_HEADER.get(settings));
SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor(settings, threadPool,
mock(AuthenticationService.class), mock(AuthorizationService.class), xPackLicenseState, mock(SSLService.class),
securityContext);
AtomicBoolean calledWrappedSender = new AtomicBoolean(false);
AtomicReference<User> sendingUser = new AtomicReference<>();
AsyncSender sender = interceptor.interceptSender(new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler) {
if (calledWrappedSender.compareAndSet(false, true) == false) {
fail("sender called more than once!");
}
sendingUser.set(securityContext.getUser());
}
});
sender.sendRequest(null, "indices:foo", null, null, null);
assertTrue(calledWrappedSender.get());
assertEquals(user, sendingUser.get());
assertEquals(user, securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed();
verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class));
verifyNoMoreInteractions(xPackLicenseState);
}
public void testSendAsyncSwitchToSystem() throws Exception {
final User user = new User("test");
final Authentication authentication = new Authentication(user, new RealmRef("ldap", "foo", "node1"), null);
authentication.writeToContext(threadContext, cryptoService, AuthenticationService.SIGN_USER_HEADER.get(settings));
threadContext.putTransient(AuthorizationService.ORIGINATING_ACTION_KEY, "indices:foo");
SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor(settings, threadPool,
mock(AuthenticationService.class), mock(AuthorizationService.class), xPackLicenseState, mock(SSLService.class),
securityContext);
AtomicBoolean calledWrappedSender = new AtomicBoolean(false);
AtomicReference<User> sendingUser = new AtomicReference<>();
AsyncSender sender = interceptor.interceptSender(new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler) {
if (calledWrappedSender.compareAndSet(false, true) == false) {
fail("sender called more than once!");
}
sendingUser.set(securityContext.getUser());
}
});
sender.sendRequest(null, "internal:foo", null, null, null);
assertTrue(calledWrappedSender.get());
assertNotEquals(user, sendingUser.get());
assertEquals(SystemUser.INSTANCE, sendingUser.get());
assertEquals(user, securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed();
verify(securityContext).executeAsUser(any(User.class), any(Consumer.class));
verifyNoMoreInteractions(xPackLicenseState);
}
public void testSendWithoutUser() throws Exception {
SecurityServerTransportInterceptor interceptor = new SecurityServerTransportInterceptor(settings, threadPool,
mock(AuthenticationService.class), mock(AuthorizationService.class), xPackLicenseState, mock(SSLService.class),
securityContext);
assertNull(securityContext.getUser());
AsyncSender sender = interceptor.interceptSender(new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(DiscoveryNode node, String action, TransportRequest request,
TransportRequestOptions options, TransportResponseHandler<T> handler) {
fail("sender should not be called!");
}
});
IllegalStateException e =
expectThrows(IllegalStateException.class, () -> sender.sendRequest(null, "indices:foo", null, null, null));
assertEquals("there should always be a user when sending a message", e.getMessage());
assertNull(securityContext.getUser());
verify(xPackLicenseState).isAuthAllowed();
verify(securityContext, never()).executeAsUser(any(User.class), any(Consumer.class));
verifyNoMoreInteractions(xPackLicenseState);
}
public void testContextRestoreResponseHandler() throws Exception {
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
threadContext.putTransient("foo", "bar");
threadContext.putHeader("key", "value");
try (ThreadContext.StoredContext storedContext = threadContext.stashContext()) {
threadContext.putTransient("foo", "different_bar");
threadContext.putHeader("key", "value2");
TransportResponseHandler<Empty> handler = new ContextRestoreResponseHandler<>(storedContext,
new TransportResponseHandler<Empty>() {
@Override
public Empty newInstance() {
return Empty.INSTANCE;
}
@Override
public void handleResponse(Empty response) {
assertEquals("bar", threadContext.getTransient("foo"));
assertEquals("value", threadContext.getHeader("key"));
}
@Override
public void handleException(TransportException exp) {
assertEquals("bar", threadContext.getTransient("foo"));
assertEquals("value", threadContext.getHeader("key"));
}
@Override
public String executor() {
return null;
}
});
handler.handleResponse(null);
handler.handleException(null);
}
}
}

View File

@ -12,12 +12,14 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.transport.TransportChannel; import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.action.SecurityActionMapper;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportSettings; import org.elasticsearch.transport.TransportSettings;
import org.elasticsearch.xpack.security.authc.Authentication.RealmRef;
import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService; import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.elasticsearch.xpack.security.authz.permission.Role;
import org.elasticsearch.xpack.security.authz.permission.SuperuserRole;
import org.elasticsearch.xpack.security.user.SystemUser; import org.elasticsearch.xpack.security.user.SystemUser;
import org.elasticsearch.xpack.security.user.User; import org.elasticsearch.xpack.security.user.User;
import org.elasticsearch.xpack.security.user.XPackUser; import org.elasticsearch.xpack.security.user.XPackUser;
@ -26,21 +28,23 @@ import org.junit.Before;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import static org.elasticsearch.mock.orig.Mockito.times;
import static org.elasticsearch.xpack.security.support.Exceptions.authenticationError; import static org.elasticsearch.xpack.security.support.Exceptions.authenticationError;
import static org.elasticsearch.xpack.security.support.Exceptions.authorizationError; import static org.elasticsearch.xpack.security.support.Exceptions.authorizationError;
import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any; import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class ServerTransportFilterTests extends ESTestCase { public class ServerTransportFilterTests extends ESTestCase {
private AuthenticationService authcService; private AuthenticationService authcService;
private AuthorizationService authzService; private AuthorizationService authzService;
private ServerTransportFilter filter;
private TransportChannel channel; private TransportChannel channel;
@Before @Before
@ -49,8 +53,6 @@ public class ServerTransportFilterTests extends ESTestCase {
authzService = mock(AuthorizationService.class); authzService = mock(AuthorizationService.class);
channel = mock(TransportChannel.class); channel = mock(TransportChannel.class);
when(channel.getProfileName()).thenReturn(TransportSettings.DEFAULT_PROFILE); when(channel.getProfileName()).thenReturn(TransportSettings.DEFAULT_PROFILE);
filter = new ServerTransportFilter.NodeProfile(authcService, authzService,
new ThreadContext(Settings.EMPTY), false);
} }
public void testInbound() throws Exception { public void testInbound() throws Exception {
@ -58,6 +60,7 @@ public class ServerTransportFilterTests extends ESTestCase {
Authentication authentication = mock(Authentication.class); Authentication authentication = mock(Authentication.class);
when(authentication.getUser()).thenReturn(SystemUser.INSTANCE); when(authentication.getUser()).thenReturn(SystemUser.INSTANCE);
when(authcService.authenticate("_action", request, null)).thenReturn(authentication); when(authcService.authenticate("_action", request, null)).thenReturn(authentication);
ServerTransportFilter filter = getClientOrNodeFilter();
PlainActionFuture future = new PlainActionFuture(); PlainActionFuture future = new PlainActionFuture();
filter.inbound("_action", request, channel, future); filter.inbound("_action", request, channel, future);
//future.get(); // don't block it's not called really just mocked //future.get(); // don't block it's not called really just mocked
@ -67,6 +70,7 @@ public class ServerTransportFilterTests extends ESTestCase {
public void testInboundAuthenticationException() throws Exception { public void testInboundAuthenticationException() throws Exception {
TransportRequest request = mock(TransportRequest.class); TransportRequest request = mock(TransportRequest.class);
doThrow(authenticationError("authc failed")).when(authcService).authenticate("_action", request, null); doThrow(authenticationError("authc failed")).when(authcService).authenticate("_action", request, null);
ServerTransportFilter filter = getClientOrNodeFilter();
try { try {
PlainActionFuture future = new PlainActionFuture(); PlainActionFuture future = new PlainActionFuture();
filter.inbound("_action", request, channel, future); filter.inbound("_action", request, channel, future);
@ -79,6 +83,7 @@ public class ServerTransportFilterTests extends ESTestCase {
} }
public void testInboundAuthorizationException() throws Exception { public void testInboundAuthorizationException() throws Exception {
ServerTransportFilter filter = getClientOrNodeFilter();
TransportRequest request = mock(TransportRequest.class); TransportRequest request = mock(TransportRequest.class);
Authentication authentication = mock(Authentication.class); Authentication authentication = mock(Authentication.class);
when(authcService.authenticate("_action", request, null)).thenReturn(authentication); when(authcService.authenticate("_action", request, null)).thenReturn(authentication);
@ -101,4 +106,57 @@ public class ServerTransportFilterTests extends ESTestCase {
assertThat(e.getMessage(), equalTo("authz failed")); assertThat(e.getMessage(), equalTo("authz failed"));
} }
} }
public void testClientProfileRejectsNodeActions() throws Exception {
TransportRequest request = mock(TransportRequest.class);
ServerTransportFilter filter = getClientFilter();
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> filter.inbound("internal:foo/bar", request, channel, new PlainActionFuture<>()));
assertEquals("executing internal/shard actions is considered malicious and forbidden", e.getMessage());
e = expectThrows(ElasticsearchSecurityException.class,
() -> filter.inbound("indices:action" + randomFrom("[s]", "[p]", "[r]", "[n]", "[s][p]", "[s][r]", "[f]"),
request, channel, new PlainActionFuture<>()));
assertEquals("executing internal/shard actions is considered malicious and forbidden", e.getMessage());
verifyZeroInteractions(authcService);
}
public void testNodeProfileAllowsNodeActions() throws Exception {
final String internalAction = "internal:foo/bar";
final String nodeOrShardAction = "indices:action" + randomFrom("[s]", "[p]", "[r]", "[n]", "[s][p]", "[s][r]", "[f]");
ServerTransportFilter filter = getNodeFilter();
TransportRequest request = mock(TransportRequest.class);
Authentication authentication = new Authentication(new User("test", "superuser"), new RealmRef("test", "test", "node1"), null);
final Collection<Role> userRoles = Collections.singletonList(SuperuserRole.INSTANCE);
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[1];
callback.onResponse(authentication.getUser().equals(i.getArguments()[0]) ? userRoles : Collections.emptyList());
return Void.TYPE;
}).when(authzService).roles(any(User.class), any(ActionListener.class));
when(authcService.authenticate(internalAction, request, null)).thenReturn(authentication);
when(authcService.authenticate(nodeOrShardAction, request, null)).thenReturn(authentication);
filter.inbound(internalAction, request, channel, new PlainActionFuture<>());
verify(authcService).authenticate(internalAction, request, null);
verify(authzService).roles(eq(authentication.getUser()), any(ActionListener.class));
verify(authzService).authorize(authentication, internalAction, request, userRoles, Collections.emptyList());
filter.inbound(nodeOrShardAction, request, channel, new PlainActionFuture<>());
verify(authcService).authenticate(nodeOrShardAction, request, null);
verify(authzService, times(2)).roles(eq(authentication.getUser()), any(ActionListener.class));
verify(authzService).authorize(authentication, nodeOrShardAction, request, userRoles, Collections.emptyList());
verifyNoMoreInteractions(authcService, authzService);
}
private ServerTransportFilter getClientOrNodeFilter() {
return randomBoolean() ? getNodeFilter() : getClientFilter();
}
private ServerTransportFilter.ClientProfile getClientFilter() {
return new ServerTransportFilter.ClientProfile(authcService, authzService, new ThreadContext(Settings.EMPTY), false);
}
private ServerTransportFilter.NodeProfile getNodeFilter() {
return new ServerTransportFilter.NodeProfile(authcService, authzService, new ThreadContext(Settings.EMPTY), false);
}
} }

View File

@ -1,358 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.transport;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.inject.AbstractModule;
import org.elasticsearch.common.inject.Module;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.plugins.NetworkPlugin;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.PluginsService;
import org.elasticsearch.script.ScriptService;
import org.elasticsearch.search.SearchRequestParsers;
import org.elasticsearch.transport.MockTcpTransportPlugin;
import org.elasticsearch.transport.TransportInterceptor;
import org.elasticsearch.watcher.ResourceWatcherService;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authz.AuthorizationService;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.ESIntegTestCase.ClusterScope;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportRequest;
import org.elasticsearch.transport.TransportRequestHandler;
import org.elasticsearch.transport.TransportResponse;
import org.elasticsearch.transport.TransportResponseHandler;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.transport.TransportSettings;
import org.elasticsearch.xpack.security.user.SystemUser;
import org.elasticsearch.xpack.ssl.SSLService;
import org.mockito.InOrder;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.test.ESIntegTestCase.Scope.SUITE;
import static org.hamcrest.Matchers.equalTo;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.isA;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ClusterScope(scope = SUITE, numDataNodes = 0)
public class TransportFilterTests extends ESIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> getMockPlugins() {
return Collections.emptyList();
}
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return Arrays.asList(InternalPluginServerTransportServiceInterceptor.TestPlugin.class, MockTcpTransportPlugin.class);
}
@Override
protected Collection<Class<? extends Plugin>> transportClientPlugins() {
return Collections.singleton(MockTcpTransportPlugin.class);
}
public void test() throws Exception {
String source = internalCluster().startNode();
DiscoveryNode sourceNode = internalCluster().getInstance(ClusterService.class, source).localNode();
TransportService sourceService = internalCluster().getInstance(TransportService.class, source);
InternalPluginServerTransportServiceInterceptor sourceInterceptor = internalCluster().getInstance(PluginsService.class, source)
.filterPlugins(InternalPluginServerTransportServiceInterceptor.TestPlugin.class).stream().findFirst().get().interceptor;
String target = internalCluster().startNode();
DiscoveryNode targetNode = internalCluster().getInstance(ClusterService.class, target).localNode();
TransportService targetService = internalCluster().getInstance(TransportService.class, target);
InternalPluginServerTransportServiceInterceptor targetInterceptor = internalCluster().getInstance(PluginsService.class, target)
.filterPlugins(InternalPluginServerTransportServiceInterceptor.TestPlugin.class).stream().findFirst().get().interceptor;
CountDownLatch latch = new CountDownLatch(2);
targetService.registerRequestHandler("_action", Request::new, ThreadPool.Names.SAME,
new RequestHandler(new Response("trgt_to_src"), latch));
sourceService.sendRequest(targetNode, "_action", new Request("src_to_trgt"),
new ResponseHandler(new Response("trgt_to_src"), latch));
await(latch);
latch = new CountDownLatch(2);
sourceService.registerRequestHandler("_action", Request::new, ThreadPool.Names.SAME,
new RequestHandler(new Response("src_to_trgt"), latch));
targetService.sendRequest(sourceNode, "_action", new Request("trgt_to_src"),
new ResponseHandler(new Response("src_to_trgt"), latch));
await(latch);
ServerTransportFilter sourceServerFilter = sourceInterceptor.transportFilter(TransportSettings.DEFAULT_PROFILE);
ServerTransportFilter targetServerFilter = targetInterceptor.transportFilter(TransportSettings.DEFAULT_PROFILE);
AuthenticationService sourceAuth = internalCluster().getInstance(AuthenticationService.class, source);
AuthenticationService targetAuth = internalCluster().getInstance(AuthenticationService.class, target);
InOrder inOrder = inOrder(sourceAuth, targetServerFilter, targetAuth, sourceServerFilter);
inOrder.verify(sourceAuth).attachUserIfMissing(SystemUser.INSTANCE);
inOrder.verify(targetServerFilter).inbound(eq("_action"), eq(new Request("src_to_trgt")), isA(TransportChannel.class),
any(ActionListener.class));
inOrder.verify(targetAuth).attachUserIfMissing(SystemUser.INSTANCE);
inOrder.verify(sourceServerFilter).inbound(eq("_action"), eq(new Request("trgt_to_src")), isA(TransportChannel.class),
any(ActionListener.class));
}
public static class InternalPlugin extends Plugin {
@Override
public Collection<Module> createGuiceModules() {
return Collections.singletonList(new TestTransportFilterModule());
}
}
public static class TestTransportFilterModule extends AbstractModule {
@Override
protected void configure() {
bind(AuthenticationService.class).toInstance(mock(AuthenticationService.class));
bind(AuthorizationService.class).toInstance(mock(AuthorizationService.class));
}
}
public static class Request extends TransportRequest {
private String msg;
public Request() {
}
Request(String msg) {
this.msg = msg;
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
msg = in.readString();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(msg);
}
@Override
public String toString() {
return msg;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Request request = (Request) o;
if (!msg.equals(request.msg)) return false;
return true;
}
@Override
public int hashCode() {
return msg.hashCode();
}
}
static class Response extends TransportResponse {
private String msg;
Response() {
}
Response(String msg) {
this.msg = msg;
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
msg = in.readString();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(msg);
}
@Override
public String toString() {
return msg;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Response response = (Response) o;
if (!msg.equals(response.msg)) return false;
return true;
}
@Override
public int hashCode() {
return msg.hashCode();
}
}
static class RequestHandler implements TransportRequestHandler<Request> {
private final Response response;
private final CountDownLatch latch;
RequestHandler(Response response, CountDownLatch latch) {
this.response = response;
this.latch = latch;
}
@Override
public void messageReceived(Request request, TransportChannel channel) throws Exception {
channel.sendResponse(response);
latch.countDown();
}
}
class ResponseHandler implements TransportResponseHandler<Response> {
private final Response response;
private final CountDownLatch latch;
ResponseHandler(Response response, CountDownLatch latch) {
this.response = response;
this.latch = latch;
}
@Override
public Response newInstance() {
return new Response();
}
@Override
public void handleResponse(Response response) {
assertThat(response, equalTo(this.response));
latch.countDown();
}
@Override
public void handleException(TransportException exp) {
logger.error("execution of request failed", exp);
fail("execution of request failed");
}
@Override
public String executor() {
return ThreadPool.Names.SAME;
}
}
private static void await(CountDownLatch latch) throws Exception {
if (!latch.await(5, TimeUnit.SECONDS)) {
fail("waiting too long for request");
}
}
// Sub class the security transport to always inject a mock for testing
public static class InternalPluginServerTransportServiceInterceptor extends SecurityServerTransportInterceptor {
public static class TestPlugin extends Plugin implements NetworkPlugin {
AuthenticationService authenticationService = mock(AuthenticationService.class);
AuthorizationService authorizationService = mock(AuthorizationService.class);
InternalPluginServerTransportServiceInterceptor interceptor;
@Override
public Collection<Object> createComponents(Client client, ClusterService clusterService, ThreadPool threadPool,
ResourceWatcherService resourceWatcherService, ScriptService scriptService,
SearchRequestParsers searchRequestParsers) {
interceptor = new InternalPluginServerTransportServiceInterceptor(clusterService.getSettings(), threadPool,
authenticationService, authorizationService);
return Collections.emptyList();
}
@Override
public Collection<Module> createGuiceModules() {
return Collections.singleton((Module) binder -> {
binder.bind(AuthenticationService.class).toInstance(authenticationService);
binder.bind(AuthorizationService.class).toInstance(authorizationService);
});
}
@Override
public List<TransportInterceptor> getTransportInterceptors(NamedWriteableRegistry namedWriteableRegistry) {
return Collections.singletonList(new TransportInterceptor() {
@Override
public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(String action, String executor,
TransportRequestHandler<T> actualHandler) {
return interceptor.interceptHandler(action, executor, actualHandler);
}
@Override
public AsyncSender interceptSender(AsyncSender sender) {
return interceptor.interceptSender(sender);
}
});
}
}
public InternalPluginServerTransportServiceInterceptor(Settings settings, ThreadPool threadPool,
AuthenticationService authenticationService,
AuthorizationService authorizationService) {
super(settings, threadPool,authenticationService, authorizationService, mock(XPackLicenseState.class),
mock(SSLService.class));
when(licenseState.isAuthAllowed()).thenReturn(true);
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[3];
callback.onResponse(null);
return Void.TYPE;
}).when(authorizationService).roles(any(), any(ActionListener.class));
}
@Override
protected Map<String, ServerTransportFilter> initializeProfileFilters() {
ServerTransportFilter.NodeProfile mock = mock(ServerTransportFilter.NodeProfile.class);
try {
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[3];
callback.onResponse(null);
return Void.TYPE;
}).when(mock).inbound(any(), any(), any(), any(ActionListener.class));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
return Collections.singletonMap(TransportSettings.DEFAULT_PROFILE,
mock);
}
}
}