Convert security rest filter to rest handler wrapper (elastic/elasticsearch#4234)

* Convert security rest filter to rest handler wrapper

This is the xpack side of elastic/elasticsearchelastic/elasticsearch#21905

Original commit: elastic/x-pack-elasticsearch@38bfa771b6
This commit is contained in:
Ryan Ernst 2016-12-02 14:55:10 -08:00 committed by GitHub
parent f1a4a2fb73
commit 923926ef28
6 changed files with 77 additions and 79 deletions

View File

@ -22,6 +22,7 @@ import org.elasticsearch.common.settings.ClusterSettings;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.http.HttpServerTransport;
import org.elasticsearch.index.IndexModule;
@ -102,6 +103,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
public class XPackPlugin extends Plugin implements ScriptPlugin, ActionPlugin, IngestPlugin, NetworkPlugin {
@ -465,4 +467,8 @@ public class XPackPlugin extends Plugin implements ScriptPlugin, ActionPlugin, I
return security.getHttpTransports(settings, threadPool, bigArrays, circuitBreakerService, namedWriteableRegistry, networkService);
}
@Override
public UnaryOperator<RestHandler> getRestHandlerWrapper(ThreadContext threadContext) {
return security.getRestHandlerWrapper(threadContext);
}
}

View File

@ -100,7 +100,7 @@ import org.elasticsearch.xpack.security.authz.store.FileRolesStore;
import org.elasticsearch.xpack.security.authz.store.NativeRolesStore;
import org.elasticsearch.xpack.security.authz.store.ReservedRolesStore;
import org.elasticsearch.xpack.security.crypto.CryptoService;
import org.elasticsearch.xpack.security.rest.SecurityRestModule;
import org.elasticsearch.xpack.security.rest.SecurityRestFilter;
import org.elasticsearch.xpack.security.rest.action.RestAuthenticateAction;
import org.elasticsearch.xpack.security.rest.action.realm.RestClearRealmCacheAction;
import org.elasticsearch.xpack.security.rest.action.role.RestClearRolesCacheAction;
@ -134,6 +134,7 @@ import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.function.UnaryOperator;
import java.util.stream.Collectors;
import static java.util.Collections.emptyList;
@ -165,6 +166,7 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
* an instance of TransportInterceptor way earlier before createComponents is called. */
private final SetOnce<TransportInterceptor> securityIntercepter = new SetOnce<>();
private final SetOnce<IPFilter> ipFilter = new SetOnce<>();
private final SetOnce<AuthenticationService> authcService = new SetOnce<>();
public Security(Settings settings, Environment env, XPackLicenseState licenseState, SSLService sslService) throws IOException {
this.settings = settings;
@ -225,7 +227,6 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
b.bind(AuditTrail.class).to(AuditTrailService.class); // interface used by some actions...
}
});
modules.add(new SecurityRestModule(settings));
modules.add(new SecurityActionModule(settings));
return modules;
}
@ -310,9 +311,9 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
logger.debug("Using authentication failure handler from extension [" + extensionName + "]");
}
final AuthenticationService authcService = new AuthenticationService(settings, realms, auditTrailService,
cryptoService, failureHandler, threadPool, anonymousUser);
components.add(authcService);
authcService.set(new AuthenticationService(settings, realms, auditTrailService,
cryptoService, failureHandler, threadPool, anonymousUser));
components.add(authcService.get());
final FileRolesStore fileRolesStore = new FileRolesStore(settings, env, resourceWatcherService);
final NativeRolesStore nativeRolesStore = new NativeRolesStore(settings, client);
@ -332,7 +333,7 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
ipFilter.set(new IPFilter(settings, auditTrailService, clusterService.getClusterSettings(), licenseState));
components.add(ipFilter.get());
DestructiveOperations destructiveOperations = new DestructiveOperations(settings, clusterService.getClusterSettings());
securityIntercepter.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService, authzService, licenseState,
securityIntercepter.set(new SecurityServerTransportInterceptor(settings, threadPool, authcService.get(), authzService, licenseState,
sslService, securityContext, destructiveOperations));
return components;
}
@ -725,4 +726,12 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin {
threadPool));
}
@Override
public UnaryOperator<RestHandler> getRestHandlerWrapper(ThreadContext threadContext) {
if (enabled == false || transportClientMode) {
return null;
}
return handler -> new SecurityRestFilter(settings, licenseState, sslService, threadContext, authcService.get(), handler);
}
}

View File

@ -7,59 +7,49 @@ package org.elasticsearch.xpack.security.rest;
import io.netty.handler.ssl.SslHandler;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.logging.log4j.util.Supplier;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.logging.Loggers;
import org.elasticsearch.common.logging.ESLoggerFactory;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.http.netty4.Netty4HttpRequest;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestFilter;
import org.elasticsearch.rest.RestFilterChain;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestRequest.Method;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.transport.ServerTransportFilter;
import org.elasticsearch.xpack.ssl.SSLService;
import static org.elasticsearch.xpack.XPackSettings.HTTP_SSL_ENABLED;
public class SecurityRestFilter extends RestFilter {
public class SecurityRestFilter implements RestHandler {
private static final Logger logger = ESLoggerFactory.getLogger(SecurityRestFilter.class);
private final RestHandler restHandler;
private final AuthenticationService service;
private final Logger logger;
private final XPackLicenseState licenseState;
private final ThreadContext threadContext;
private final RestController restController;
private final boolean extractClientCertificate;
@Inject
public SecurityRestFilter(AuthenticationService service, RestController controller, Settings settings,
ThreadPool threadPool, XPackLicenseState licenseState, SSLService sslService) {
public SecurityRestFilter(Settings settings, XPackLicenseState licenseState, SSLService sslService,
ThreadContext threadContext, AuthenticationService service, RestHandler restHandler) {
this.restHandler = restHandler;
this.service = service;
this.licenseState = licenseState;
this.threadContext = threadPool.getThreadContext();
this.logger = Loggers.getLogger(getClass(), settings);
this.threadContext = threadContext;
final boolean ssl = HTTP_SSL_ENABLED.get(settings);
Settings httpSSLSettings = SSLService.getHttpTransportSSLSettings(settings);
this.extractClientCertificate = ssl && sslService.isSSLClientAuthEnabled(httpSSLSettings);
controller.registerFilter(this);
this.restController = controller;
}
@Override
public int order() {
return Integer.MIN_VALUE;
}
@Override
public void process(RestRequest request, RestChannel channel, NodeClient client, RestFilterChain filterChain) throws Exception {
public void handleRequest(RestRequest request, RestChannel channel, NodeClient client) throws Exception {
if (licenseState.isAuthAllowed() && request.method() != Method.OPTIONS) {
// CORS - allow for preflight unauthenticated OPTIONS request
if (extractClientCertificate) {
@ -68,12 +58,20 @@ public class SecurityRestFilter extends RestFilter {
assert handler != null;
ServerTransportFilter.extactClientCertificates(logger, threadContext, handler.engine(), nettyHttpRequest.getChannel());
}
service.authenticate(request, ActionListener.wrap((authentication) -> {
service.authenticate(request, ActionListener.wrap(
authentication -> {
RemoteHostHeader.process(request, threadContext);
filterChain.continueProcessing(request, channel, client);
}, (e) -> restController.sendErrorResponse(request, channel, e)));
restHandler.handleRequest(request, channel, client);
}, e -> {
try {
channel.sendResponse(new BytesRestResponse(channel, e));
} catch (Exception inner) {
inner.addSuppressed(e);
logger.error((Supplier<?>) () -> new ParameterizedMessage("failed to send failure response for uri [{}]", request.uri()), inner);
}
}));
} else {
filterChain.continueProcessing(request, channel, client);
restHandler.handleRequest(request, channel, client);
}
}
}

View File

@ -1,21 +0,0 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.rest;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.xpack.security.support.AbstractSecurityModule;
public class SecurityRestModule extends AbstractSecurityModule.Node {
public SecurityRestModule(Settings settings) {
super(settings);
}
@Override
protected void configureNode() {
bind(SecurityRestFilter.class).asEagerSingleton();
}
}

View File

@ -212,7 +212,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
}
assert filter != null;
final Thread executingThread = Thread.currentThread();
Consumer<Void> consumer = (x) -> {
ActionListener.CheckedConsumer<Void> consumer = (x) -> {
final Executor executor;
if (executingThread == Thread.currentThread()) {
// only fork off if we get called on another thread this means we moved to

View File

@ -8,17 +8,21 @@ package org.elasticsearch.xpack.security.rest;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestFilterChain;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.BytesRestResponse;
import org.elasticsearch.rest.RestChannel;
import org.elasticsearch.rest.RestHandler;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.security.authc.Authentication;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.ssl.SSLService;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import static org.elasticsearch.xpack.security.support.Exceptions.authenticationError;
import static org.mockito.Matchers.any;
@ -26,29 +30,28 @@ import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
public class SecurityRestFilterTests extends ESTestCase {
private AuthenticationService authcService;
private RestChannel channel;
private RestFilterChain chain;
private SecurityRestFilter filter;
private XPackLicenseState licenseState;
private RestController restController;
private RestHandler restHandler;
@Before
public void init() throws Exception {
authcService = mock(AuthenticationService.class);
restController = mock(RestController.class);
channel = mock(RestChannel.class);
chain = mock(RestFilterChain.class);
licenseState = mock(XPackLicenseState.class);
when(licenseState.isAuthAllowed()).thenReturn(true);
restHandler = mock(RestHandler.class);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
filter = new SecurityRestFilter(authcService, restController, Settings.EMPTY, threadPool, licenseState, mock(SSLService.class));
verify(restController).registerFilter(filter);
filter = new SecurityRestFilter(Settings.EMPTY, licenseState, mock(SSLService.class),
threadPool.getThreadContext(), authcService, restHandler);
}
public void testProcess() throws Exception {
@ -56,20 +59,20 @@ public class SecurityRestFilterTests extends ESTestCase {
Authentication authentication = mock(Authentication.class);
doAnswer((i) -> {
ActionListener callback =
(ActionListener) i.getArguments()[1];
(ActionListener) i.getArguments()[1];
callback.onResponse(authentication);
return Void.TYPE;
}).when(authcService).authenticate(eq(request), any(ActionListener.class));
filter.process(request, channel, null, chain);
verify(chain).continueProcessing(request, channel, null);
filter.handleRequest(request, channel, null);
verify(restHandler).handleRequest(request, channel, null);
verifyZeroInteractions(channel);
}
public void testProcessBasicLicense() throws Exception {
RestRequest request = mock(RestRequest.class);
when(licenseState.isAuthAllowed()).thenReturn(false);
filter.process(request, channel, null, chain);
verify(chain).continueProcessing(request, channel, null);
filter.handleRequest(request, channel, null);
verifyZeroInteractions(restHandler);
verifyZeroInteractions(channel, authcService);
}
@ -82,17 +85,20 @@ public class SecurityRestFilterTests extends ESTestCase {
callback.onFailure(exception);
return Void.TYPE;
}).when(authcService).authenticate(eq(request), any(ActionListener.class));
filter.process(request, channel, null, chain);
verify(restController).sendErrorResponse(request, channel, exception);
verifyZeroInteractions(channel);
verifyZeroInteractions(chain);
when(channel.request()).thenReturn(request);
when(channel.newErrorBuilder()).thenReturn(JsonXContent.contentBuilder());
filter.handleRequest(request, channel, null);
ArgumentCaptor<BytesRestResponse> response = ArgumentCaptor.forClass(BytesRestResponse.class);
verify(channel).sendResponse(response.capture());
assertEquals(RestStatus.UNAUTHORIZED, response.getValue().status());
verifyZeroInteractions(restHandler);
}
public void testProcessOptionsMethod() throws Exception {
RestRequest request = mock(RestRequest.class);
when(request.method()).thenReturn(RestRequest.Method.OPTIONS);
filter.process(request, channel, null, chain);
verify(chain).continueProcessing(request, channel, null);
filter.handleRequest(request, channel, null);
verifyZeroInteractions(restHandler);
verifyZeroInteractions(channel);
verifyZeroInteractions(authcService);
}