diff --git a/src/main/java/org/elasticsearch/shield/SecurityFilter.java b/src/main/java/org/elasticsearch/shield/SecurityFilter.java index 685eb10a884..f82a59d4a97 100644 --- a/src/main/java/org/elasticsearch/shield/SecurityFilter.java +++ b/src/main/java/org/elasticsearch/shield/SecurityFilter.java @@ -26,7 +26,7 @@ import org.elasticsearch.shield.authz.AuthorizationService; import org.elasticsearch.shield.authz.SystemRole; import org.elasticsearch.shield.key.KeyService; import org.elasticsearch.shield.key.SignatureException; -import org.elasticsearch.shield.transport.TransportFilter; +import org.elasticsearch.shield.transport.ServerTransportFilter; import org.elasticsearch.transport.TransportRequest; import java.util.ArrayList; @@ -140,17 +140,17 @@ public class SecurityFilter extends AbstractComponent { } } - public static class Transport extends TransportFilter.Base { + public static class ServerTransport implements ServerTransportFilter { private final SecurityFilter filter; @Inject - public Transport(SecurityFilter filter) { + public ServerTransport(SecurityFilter filter) { this.filter = filter; } @Override - public void inboundRequest(String action, TransportRequest request) { + public void inbound(String action, TransportRequest request) { filter.authenticateAndAuthorize(action, request); } } diff --git a/src/main/java/org/elasticsearch/shield/transport/SecuredTransportModule.java b/src/main/java/org/elasticsearch/shield/transport/SecuredTransportModule.java index 084e5a0af20..15eed1e48e1 100644 --- a/src/main/java/org/elasticsearch/shield/transport/SecuredTransportModule.java +++ b/src/main/java/org/elasticsearch/shield/transport/SecuredTransportModule.java @@ -53,11 +53,11 @@ public class SecuredTransportModule extends AbstractShieldModule.Spawn implement if (clientMode) { // no ip filtering on the client bind(N2NNettyUpstreamHandler.class).toProvider(Providers.of(null)); - bind(TransportFilter.class).toInstance(TransportFilter.NOOP); + bind(ServerTransportFilter.class).toInstance(ServerTransportFilter.NOOP); return; } - bind(TransportFilter.class).to(SecurityFilter.Transport.class).asEagerSingleton(); + bind(ServerTransportFilter.class).to(SecurityFilter.ServerTransport.class).asEagerSingleton(); if (settings.getAsBoolean("shield.transport.n2n.ip_filter.enabled", true)) { bind(IPFilteringN2NAuthenticator.class).asEagerSingleton(); bind(N2NNettyUpstreamHandler.class).asEagerSingleton(); diff --git a/src/main/java/org/elasticsearch/shield/transport/SecuredTransportService.java b/src/main/java/org/elasticsearch/shield/transport/SecuredTransportService.java index 033b3b477b1..72c70b2ff11 100644 --- a/src/main/java/org/elasticsearch/shield/transport/SecuredTransportService.java +++ b/src/main/java/org/elasticsearch/shield/transport/SecuredTransportService.java @@ -5,38 +5,24 @@ */ package org.elasticsearch.shield.transport; -import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.*; -import java.io.IOException; - /** * */ public class SecuredTransportService extends TransportService { - private final TransportFilter filter; + private final ServerTransportFilter filter; @Inject - public SecuredTransportService(Settings settings, Transport transport, ThreadPool threadPool, TransportFilter filter) { + public SecuredTransportService(Settings settings, Transport transport, ThreadPool threadPool, ServerTransportFilter filter) { super(settings, transport, threadPool); this.filter = filter; } - public void sendRequest(final DiscoveryNode node, final String action, final TransportRequest request, - final TransportRequestOptions options, TransportResponseHandler handler) { - try { - filter.outboundRequest(action, request); - } catch (Throwable t) { - handler.handleException(new TransportException("failed sending request", t)); - return; - } - super.sendRequest(node, action, request, options, new SecuredResponseHandler(handler, filter)); - } - @Override public void registerHandler(String action, TransportRequestHandler handler) { super.registerHandler(action, new SecuredRequestHandler(action, handler, filter)); @@ -46,9 +32,9 @@ public class SecuredTransportService extends TransportService { private final String action; private final TransportRequestHandler handler; - private final TransportFilter filter; + private final ServerTransportFilter filter; - SecuredRequestHandler(String action, TransportRequestHandler handler, TransportFilter filter) { + SecuredRequestHandler(String action, TransportRequestHandler handler, ServerTransportFilter filter) { this.action = action; this.handler = handler; this.filter = filter; @@ -62,12 +48,12 @@ public class SecuredTransportService extends TransportService { @Override @SuppressWarnings("unchecked") public void messageReceived(TransportRequest request, TransportChannel channel) throws Exception { try { - filter.inboundRequest(action, request); + filter.inbound(action, request); } catch (Throwable t) { channel.sendResponse(t); return; } - handler.messageReceived(request, new SecuredTransportChannel(channel, filter)); + handler.messageReceived(request, channel); } @Override @@ -80,86 +66,4 @@ public class SecuredTransportService extends TransportService { return handler.isForceExecution(); } } - - static class SecuredResponseHandler implements TransportResponseHandler { - - private final TransportResponseHandler handler; - private final TransportFilter filter; - - SecuredResponseHandler(TransportResponseHandler handler, TransportFilter filter) { - this.handler = handler; - this.filter = filter; - } - - @Override - public TransportResponse newInstance() { - return handler.newInstance(); - } - - @Override @SuppressWarnings("unchecked") - public void handleResponse(TransportResponse response) { - try { - filter.inboundResponse(response); - } catch (Throwable t) { - handleException(new TransportException("response received but rejected locally", t)); - return; - } - handler.handleResponse(response); - } - - @Override - public void handleException(TransportException exp) { - handler.handleException(exp); - } - - @Override - public String executor() { - return handler.executor(); - } - } - - static class SecuredTransportChannel implements TransportChannel { - - private final TransportChannel channel; - private final TransportFilter filter; - - SecuredTransportChannel(TransportChannel channel, TransportFilter filter) { - this.channel = channel; - this.filter = filter; - } - - @Override - public String action() { - return channel.action(); - } - - @Override - public void sendResponse(TransportResponse response) throws IOException { - if (filter(response)) { - channel.sendResponse(response); - } - } - - @Override - public void sendResponse(TransportResponse response, TransportResponseOptions options) throws IOException { - if (filter(response)) { - channel.sendResponse(response, options); - } - } - - private boolean filter(TransportResponse response) throws IOException { - try { - filter.outboundResponse(channel.action(), response); - } catch (Throwable t) { - channel.sendResponse(t); - return false; - } - return true; - } - - @Override - public void sendResponse(Throwable error) throws IOException { - channel.sendResponse(error); - } - } } diff --git a/src/main/java/org/elasticsearch/shield/transport/ServerTransportFilter.java b/src/main/java/org/elasticsearch/shield/transport/ServerTransportFilter.java new file mode 100644 index 00000000000..d9b50caa2fd --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/transport/ServerTransportFilter.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.transport; + +import org.elasticsearch.transport.TransportRequest; + +public interface ServerTransportFilter { + + static final ServerTransportFilter NOOP = new ServerTransportFilter() { + @Override + public void inbound(String action, TransportRequest request) {} + }; + + /** + * Called just after the given request was received by the transport. Any exception + * thrown by this method will stop the request from being handled and the error will + * be sent back to the sender. + */ + void inbound(String action, TransportRequest request); + +} diff --git a/src/main/java/org/elasticsearch/shield/transport/TransportFilter.java b/src/main/java/org/elasticsearch/shield/transport/TransportFilter.java deleted file mode 100644 index 23363343683..00000000000 --- a/src/main/java/org/elasticsearch/shield/transport/TransportFilter.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License; - * you may not use this file except in compliance with the Elastic License. - */ -package org.elasticsearch.shield.transport; - -import org.elasticsearch.transport.TransportRequest; -import org.elasticsearch.transport.TransportResponse; - -/** - * - */ -public interface TransportFilter { - - static final TransportFilter NOOP = new Base(); - - /** - * Called just before the given request is about to be sent. Any exception thrown - * by this method will stop the request from being sent. - */ - void outboundRequest(String action, TransportRequest request); - - /** - * Called just after the given request was received by the transport. Any exception - * thrown by this method will stop the request from being handled and the error will - * be sent back to the sender. - */ - void inboundRequest(String action, TransportRequest request); - - /** - * Called just before the given response is about to be sent. Any exception thrown - * by this method will stop the response from being sent and an error will be sent - * instead. - */ - void outboundResponse(String action, TransportResponse response); - - /** - * Called just after the given response was received by the transport. Any exception - * thrown by this method will stop the response from being handled normally and instead - * the error will be used as the response. - */ - void inboundResponse(TransportResponse response); - - static class Base implements TransportFilter { - - @Override - public void outboundRequest(String action, TransportRequest request) { - } - - @Override - public void inboundRequest(String action, TransportRequest request) { - } - - @Override - public void outboundResponse(String action, TransportResponse response) { - } - - @Override - public void inboundResponse(TransportResponse response) { - } - } - -} diff --git a/src/test/java/org/elasticsearch/shield/SecurityFilterTests.java b/src/test/java/org/elasticsearch/shield/SecurityFilterTests.java index fd641911525..2be3481174a 100644 --- a/src/test/java/org/elasticsearch/shield/SecurityFilterTests.java +++ b/src/test/java/org/elasticsearch/shield/SecurityFilterTests.java @@ -137,9 +137,9 @@ public class SecurityFilterTests extends ElasticsearchTestCase { @Test public void testTransport_InboundRequest() throws Exception { filter = mock(SecurityFilter.class); - SecurityFilter.Transport transport = new SecurityFilter.Transport(filter); + SecurityFilter.ServerTransport transport = new SecurityFilter.ServerTransport(filter); InternalRequest request = new InternalRequest(); - transport.inboundRequest("_action", request); + transport.inbound("_action", request); verify(filter).authenticateAndAuthorize("_action", request); } @@ -148,10 +148,10 @@ public class SecurityFilterTests extends ElasticsearchTestCase { thrown.expect(RuntimeException.class); thrown.expectMessage("process-error"); filter = mock(SecurityFilter.class); - SecurityFilter.Transport transport = new SecurityFilter.Transport(filter); + SecurityFilter.ServerTransport transport = new SecurityFilter.ServerTransport(filter); InternalRequest request = new InternalRequest(); doThrow(new RuntimeException("process-error")).when(filter).authenticateAndAuthorize("_action", request); - transport.inboundRequest("_action", request); + transport.inbound("_action", request); } @Test diff --git a/src/test/java/org/elasticsearch/shield/transport/TransportFilterTests.java b/src/test/java/org/elasticsearch/shield/transport/TransportFilterTests.java index 4251eba9735..aa77c52544f 100644 --- a/src/test/java/org/elasticsearch/shield/transport/TransportFilterTests.java +++ b/src/test/java/org/elasticsearch/shield/transport/TransportFilterTests.java @@ -68,17 +68,11 @@ public class TransportFilterTests extends ElasticsearchIntegrationTest { await(latch); - TransportFilter sourceFilter = internalCluster().getInstance(TransportFilter.class, source); - TransportFilter targetFilter = internalCluster().getInstance(TransportFilter.class, target); + ServerTransportFilter sourceFilter = internalCluster().getInstance(ServerTransportFilter.class, source); + ServerTransportFilter targetFilter = internalCluster().getInstance(ServerTransportFilter.class, target); InOrder inOrder = inOrder(sourceFilter, targetFilter); - inOrder.verify(sourceFilter).outboundRequest("_action", new Request("src_to_trgt")); - inOrder.verify(targetFilter).inboundRequest("_action", new Request("src_to_trgt")); - inOrder.verify(targetFilter).outboundResponse("_action", new Response("trgt_to_src")); - inOrder.verify(sourceFilter).inboundResponse(new Response("trgt_to_src")); - inOrder.verify(targetFilter).outboundRequest("_action", new Request("trgt_to_src")); - inOrder.verify(sourceFilter).inboundRequest("_action", new Request("trgt_to_src")); - inOrder.verify(sourceFilter).outboundResponse("_action", new Response("src_to_trgt")); - inOrder.verify(targetFilter).inboundResponse(new Response("src_to_trgt")); + inOrder.verify(targetFilter).inbound("_action", new Request("src_to_trgt")); + inOrder.verify(sourceFilter).inbound("_action", new Request("trgt_to_src")); } public static class InternalPlugin extends AbstractPlugin { @@ -102,7 +96,7 @@ public class TransportFilterTests extends ElasticsearchIntegrationTest { public static class TestTransportFilterModule extends AbstractModule { @Override protected void configure() { - bind(TransportFilter.class).toInstance(mock(TransportFilter.class)); + bind(ServerTransportFilter.class).toInstance(mock(ServerTransportFilter.class)); } } diff --git a/src/test/java/org/elasticsearch/shield/transport/n2n/IPFilteringN2NAuthenticatorTests.java b/src/test/java/org/elasticsearch/shield/transport/n2n/IPFilteringN2NAuthenticatorTests.java index 155606ff850..bb39e0a05aa 100644 --- a/src/test/java/org/elasticsearch/shield/transport/n2n/IPFilteringN2NAuthenticatorTests.java +++ b/src/test/java/org/elasticsearch/shield/transport/n2n/IPFilteringN2NAuthenticatorTests.java @@ -13,6 +13,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.env.Environment; import org.elasticsearch.test.ElasticsearchTestCase; +import org.elasticsearch.test.junit.annotations.Network; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.watcher.ResourceWatcherService; import org.junit.After; @@ -81,6 +82,7 @@ public class IPFilteringN2NAuthenticatorTests extends ElasticsearchTestCase { } @Test + @Network // requires network for name resolution public void testThatHostnamesCanBeProcessed() throws Exception { writeConfigFile("allow: localhost\ndeny: '*.google.com'");