remove NetworkUtils and InetAddress getLocalHost usage in shield

Original commit: elastic/x-pack-elasticsearch@460ef63824
This commit is contained in:
jaymode 2015-08-18 13:30:56 -04:00
parent 4756f07f2b
commit 152aeaa776
9 changed files with 119 additions and 98 deletions

View File

@ -27,8 +27,6 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.component.AbstractComponent; import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.inject.Provider; import org.elasticsearch.common.inject.Provider;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.common.transport.TransportAddress;
@ -48,6 +46,7 @@ import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.authz.Privilege; import org.elasticsearch.shield.authz.Privilege;
import org.elasticsearch.shield.rest.RemoteHostHeader; import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule; import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage; import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
import org.joda.time.DateTime; import org.joda.time.DateTime;
@ -106,6 +105,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
private final Environment environment; private final Environment environment;
private final LinkedBlockingQueue<Message> eventQueue; private final LinkedBlockingQueue<Message> eventQueue;
private final QueueConsumer queueConsumer; private final QueueConsumer queueConsumer;
private final Transport transport;
private final boolean indexToRemoteCluster; private final boolean indexToRemoteCluster;
private BulkProcessor bulkProcessor; private BulkProcessor bulkProcessor;
@ -123,12 +123,13 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
@Inject @Inject
public IndexAuditTrail(Settings settings, IndexAuditUserHolder indexingAuditUser, public IndexAuditTrail(Settings settings, IndexAuditUserHolder indexingAuditUser,
Environment environment, AuthenticationService authenticationService, Environment environment, AuthenticationService authenticationService,
Provider<Client> clientProvider) { Transport transport, Provider<Client> clientProvider) {
super(settings); super(settings);
this.auditUser = indexingAuditUser; this.auditUser = indexingAuditUser;
this.authenticationService = authenticationService; this.authenticationService = authenticationService;
this.clientProvider = clientProvider; this.clientProvider = clientProvider;
this.environment = environment; this.environment = environment;
this.transport = transport;
this.nodeName = settings.get("name"); this.nodeName = settings.get("name");
this.queueConsumer = new QueueConsumer(EsExecutors.threadName(settings, "audit-queue-consumer")); this.queueConsumer = new QueueConsumer(EsExecutors.threadName(settings, "audit-queue-consumer"));
@ -239,8 +240,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
*/ */
public void start(boolean master) { public void start(boolean master) {
if (state.compareAndSet(State.INITIALIZED, State.STARTING)) { if (state.compareAndSet(State.INITIALIZED, State.STARTING)) {
this.nodeHostName = NetworkUtils.getLocalHost().getHostName(); this.nodeHostName = transport.boundAddress().publishAddress().getHost();
this.nodeHostAddress = NetworkUtils.getLocalHost().getHostAddress(); this.nodeHostAddress = transport.boundAddress().publishAddress().getAddress();
if (client == null) { if (client == null) {
initializeClient(); initializeClient();
@ -461,7 +462,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
Message msg = new Message().start(); Message msg = new Message().start();
common("transport", type, msg.builder); common("transport", type, msg.builder);
originAttributes(message, msg.builder); originAttributes(message, msg.builder, transport);
if (action != null) { if (action != null) {
msg.builder.field(Field.ACTION, action); msg.builder.field(Field.ACTION, action);
@ -535,7 +536,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
return builder; return builder;
} }
private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder) throws IOException { private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder, Transport transport) throws IOException {
// first checking if the message originated in a rest call // first checking if the message originated in a rest call
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message); InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message);
@ -559,7 +560,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
// the call was originated locally on this node // the call was originated locally on this node
builder.field(Field.ORIGIN_TYPE, "local_node"); builder.field(Field.ORIGIN_TYPE, "local_node");
builder.field(Field.ORIGIN_ADDRESS, NetworkUtils.getLocalHost().getHostAddress()); builder.field(Field.ORIGIN_ADDRESS, transport.boundAddress().publishAddress().getAddress());
return builder; return builder;
} }

View File

@ -19,6 +19,7 @@ import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.authz.Privilege; import org.elasticsearch.shield.authz.Privilege;
import org.elasticsearch.shield.rest.RemoteHostHeader; import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule; import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage; import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
@ -38,6 +39,7 @@ public class LoggingAuditTrail implements AuditTrail {
private final String prefix; private final String prefix;
private final ESLogger logger; private final ESLogger logger;
private final Transport transport;
@Override @Override
public String name() { public String name() {
@ -45,17 +47,18 @@ public class LoggingAuditTrail implements AuditTrail {
} }
@Inject @Inject
public LoggingAuditTrail(Settings settings) { public LoggingAuditTrail(Settings settings, Transport transport) {
this(resolvePrefix(settings), Loggers.getLogger(LoggingAuditTrail.class)); this(resolvePrefix(settings, transport), transport, Loggers.getLogger(LoggingAuditTrail.class));
} }
LoggingAuditTrail(Settings settings, ESLogger logger) { LoggingAuditTrail(Settings settings, Transport transport, ESLogger logger) {
this(resolvePrefix(settings), logger); this(resolvePrefix(settings, transport), transport, logger);
} }
LoggingAuditTrail(String prefix, ESLogger logger) { LoggingAuditTrail(String prefix, Transport transport, ESLogger logger) {
this.logger = logger; this.logger = logger;
this.prefix = prefix; this.prefix = prefix;
this.transport = transport;
} }
@Override @Override
@ -63,15 +66,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), action, indices, message.getClass().getSimpleName()); logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message), action, indices); logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, transport), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message), action, message.getClass().getSimpleName()); logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, transport), action, message.getClass().getSimpleName());
} else { } else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", prefix, originAttributes(message), action); logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", prefix, originAttributes(message, transport), action);
} }
} }
} }
@ -90,15 +93,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), token.principal(), action, indices, message.getClass().getSimpleName()); logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), token.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), token.principal(), action, indices); logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), token.principal(), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), token.principal(), action, message.getClass().getSimpleName()); logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), token.principal(), action, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), token.principal(), action); logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), token.principal(), action);
} }
} }
} }
@ -117,15 +120,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), action, indices, message.getClass().getSimpleName()); logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message), action, indices); logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, transport), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message), action, message.getClass().getSimpleName()); logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, transport), action, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", prefix, originAttributes(message), action); logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", prefix, originAttributes(message, transport), action);
} }
} }
} }
@ -144,9 +147,9 @@ public class LoggingAuditTrail implements AuditTrail {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, realm, originAttributes(message), token.principal(), action, indices, message.getClass().getSimpleName()); logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, realm, originAttributes(message, transport), token.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]", prefix, realm, originAttributes(message), token.principal(), action, message.getClass().getSimpleName()); logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]", prefix, realm, originAttributes(message, transport), token.principal(), action, message.getClass().getSimpleName());
} }
} }
} }
@ -166,9 +169,9 @@ public class LoggingAuditTrail implements AuditTrail {
if (user.isSystem() && Privilege.SYSTEM.predicate().apply(action)) { if (user.isSystem() && Privilege.SYSTEM.predicate().apply(action)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
if (indices != null) { if (indices != null) {
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName()); logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName()); logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
} }
} }
return; return;
@ -176,15 +179,15 @@ public class LoggingAuditTrail implements AuditTrail {
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName()); logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), user.principal(), action, indices); logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName()); logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), user.principal(), action); logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), user.principal(), action);
} }
} }
} }
@ -194,15 +197,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName()); logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), user.principal(), action, indices); logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName()); logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), user.principal(), action); logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), user.principal(), action);
} }
} }
} }
@ -241,7 +244,7 @@ public class LoggingAuditTrail implements AuditTrail {
return "origin_address=[" + request.getRemoteAddress() + "]"; return "origin_address=[" + request.getRemoteAddress() + "]";
} }
static String originAttributes(TransportMessage message) { static String originAttributes(TransportMessage message, Transport transport) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
// first checking if the message originated in a rest call // first checking if the message originated in a rest call
@ -265,21 +268,21 @@ public class LoggingAuditTrail implements AuditTrail {
// the call was originated locally on this node // the call was originated locally on this node
return builder.append("origin_type=[local_node], origin_address=[") return builder.append("origin_type=[local_node], origin_address=[")
.append(NetworkUtils.getLocalHost().getHostAddress()) .append(transport.boundAddress().publishAddress().getAddress())
.append("]") .append("]")
.toString(); .toString();
} }
static String resolvePrefix(Settings settings) { static String resolvePrefix(Settings settings, Transport transport) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_address", false)) { if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_address", false)) {
String address = NetworkUtils.getLocalHost().getHostAddress(); String address = transport.boundAddress().publishAddress().getAddress();
if (address != null) { if (address != null) {
builder.append("[").append(address).append("] "); builder.append("[").append(address).append("] ");
} }
} }
if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_name", false)) { if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_name", false)) {
String hostName = NetworkUtils.getLocalHost().getHostAddress(); String hostName = transport.boundAddress().publishAddress().getHost();
if (hostName != null) { if (hostName != null) {
builder.append("[").append(hostName).append("] "); builder.append("[").append(hostName).append("] ");
} }

View File

@ -5,12 +5,17 @@
*/ */
package org.elasticsearch.shield.audit; package org.elasticsearch.shield.audit;
import org.elasticsearch.Version;
import org.elasticsearch.common.inject.Guice; import org.elasticsearch.common.inject.Guice;
import org.elasticsearch.common.inject.Injector; import org.elasticsearch.common.inject.Injector;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsModule; import org.elasticsearch.common.settings.SettingsModule;
import org.elasticsearch.indices.breaker.CircuitBreakerModule;
import org.elasticsearch.shield.audit.logfile.LoggingAuditTrail; import org.elasticsearch.shield.audit.logfile.LoggingAuditTrail;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPoolModule;
import org.elasticsearch.transport.TransportModule;
import org.junit.Test; import org.junit.Test;
import static org.hamcrest.Matchers.*; import static org.hamcrest.Matchers.*;
@ -46,13 +51,18 @@ public class AuditTrailModuleTests extends ESTestCase {
.put("shield.audit.enabled", true) .put("shield.audit.enabled", true)
.put("client.type", "node") .put("client.type", "node")
.build(); .build();
Injector injector = Guice.createInjector(new SettingsModule(settings), new AuditTrailModule(settings)); ThreadPool pool = new ThreadPool("testLogFile");
AuditTrail auditTrail = injector.getInstance(AuditTrail.class); try {
assertThat(auditTrail, instanceOf(AuditTrailService.class)); Injector injector = Guice.createInjector(new SettingsModule(settings), new AuditTrailModule(settings), new TransportModule(settings), new CircuitBreakerModule(settings), new ThreadPoolModule(pool), new Version.Module(Version.CURRENT));
AuditTrailService service = (AuditTrailService) auditTrail; AuditTrail auditTrail = injector.getInstance(AuditTrail.class);
assertThat(service.auditTrails, notNullValue()); assertThat(auditTrail, instanceOf(AuditTrailService.class));
assertThat(service.auditTrails.length, is(1)); AuditTrailService service = (AuditTrailService) auditTrail;
assertThat(service.auditTrails[0], instanceOf(LoggingAuditTrail.class)); assertThat(service.auditTrails, notNullValue());
assertThat(service.auditTrails.length, is(1));
assertThat(service.auditTrails[0], instanceOf(LoggingAuditTrail.class));
} finally {
pool.shutdown();
}
} }
@Test @Test

View File

@ -124,7 +124,7 @@ public class AuditTrailServiceTests extends ESTestCase {
@Test @Test
public void testConnectionGranted() throws Exception { public void testConnectionGranted() throws Exception {
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = randomBoolean() ? ShieldIpFilterRule.ACCEPT_ALL : IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; ShieldIpFilterRule rule = randomBoolean() ? ShieldIpFilterRule.ACCEPT_ALL : IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
service.connectionGranted(inetAddress, "client", rule); service.connectionGranted(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) { for (AuditTrail auditTrail : auditTrails) {
@ -134,7 +134,7 @@ public class AuditTrailServiceTests extends ESTestCase {
@Test @Test
public void testConnectionDenied() throws Exception { public void testConnectionDenied() throws Exception {
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
service.connectionDenied(inetAddress, "client", rule); service.connectionDenied(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) { for (AuditTrail auditTrail : auditTrails) {

View File

@ -14,8 +14,9 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.util.Providers; import org.elasticsearch.common.inject.util.Providers;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.DummyTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.transport.LocalTransportAddress; import org.elasticsearch.common.transport.LocalTransportAddress;
import org.elasticsearch.env.Environment; import org.elasticsearch.env.Environment;
@ -32,6 +33,7 @@ import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster; import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.test.ShieldIntegTestCase; import org.elasticsearch.test.ShieldIntegTestCase;
import org.elasticsearch.test.ShieldSettingsSource; import org.elasticsearch.test.ShieldSettingsSource;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportInfo; import org.elasticsearch.transport.TransportInfo;
import org.elasticsearch.transport.TransportMessage; import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest; import org.elasticsearch.transport.TransportRequest;
@ -175,9 +177,11 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
logger.info("--> settings: [{}]", settings.getAsMap().toString()); logger.info("--> settings: [{}]", settings.getAsMap().toString());
when(authService.authenticate(mock(RestRequest.class))).thenThrow(new UnsupportedOperationException("")); when(authService.authenticate(mock(RestRequest.class))).thenThrow(new UnsupportedOperationException(""));
when(authService.authenticate("_action", new LocalHostMockMessage(), user.user())).thenThrow(new UnsupportedOperationException("")); when(authService.authenticate("_action", new LocalHostMockMessage(), user.user())).thenThrow(new UnsupportedOperationException(""));
Transport transport = mock(Transport.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(DummyTransportAddress.INSTANCE, DummyTransportAddress.INSTANCE));
Environment env = new Environment(settings); Environment env = new Environment(settings);
auditor = new IndexAuditTrail(settings, user, env, authService, Providers.of(client())); auditor = new IndexAuditTrail(settings, user, env, authService, transport, Providers.of(client()));
auditor.start(true); auditor.start(true);
} }
@ -536,7 +540,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
public void testConnectionGranted() throws Exception { public void testConnectionGranted() throws Exception {
initialize(); initialize();
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditor.connectionGranted(inetAddress, "default", rule); auditor.connectionGranted(inetAddress, "default", rule);
awaitIndexCreation(resolveIndexName()); awaitIndexCreation(resolveIndexName());
@ -551,7 +555,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
@Test(expected = IndexNotFoundException.class) @Test(expected = IndexNotFoundException.class)
public void testConnectionGranted_Muted() throws Exception { public void testConnectionGranted_Muted() throws Exception {
initialize("connection_granted"); initialize("connection_granted");
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditor.connectionGranted(inetAddress, "default", rule); auditor.connectionGranted(inetAddress, "default", rule);
getClient().prepareExists(resolveIndexName()).execute().actionGet(); getClient().prepareExists(resolveIndexName()).execute().actionGet();
@ -561,7 +565,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
public void testConnectionDenied() throws Exception { public void testConnectionDenied() throws Exception {
initialize(); initialize();
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditor.connectionDenied(inetAddress, "default", rule); auditor.connectionDenied(inetAddress, "default", rule);
awaitIndexCreation(resolveIndexName()); awaitIndexCreation(resolveIndexName());
@ -576,7 +580,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
@Test(expected = IndexNotFoundException.class) @Test(expected = IndexNotFoundException.class)
public void testConnectionDenied_Muted() throws Exception { public void testConnectionDenied_Muted() throws Exception {
initialize("connection_denied"); initialize("connection_denied");
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditor.connectionDenied(inetAddress, "default", rule); auditor.connectionDenied(inetAddress, "default", rule);
getClient().prepareExists(resolveIndexName()).execute().actionGet(); getClient().prepareExists(resolveIndexName()).execute().actionGet();
@ -587,8 +591,8 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
DateTime dateTime = ISODateTimeFormat.dateTimeParser().withZoneUTC().parseDateTime((String) hit.field("@timestamp").getValue()); DateTime dateTime = ISODateTimeFormat.dateTimeParser().withZoneUTC().parseDateTime((String) hit.field("@timestamp").getValue());
assertThat(dateTime.isBefore(DateTime.now(DateTimeZone.UTC)), is(true)); assertThat(dateTime.isBefore(DateTime.now(DateTimeZone.UTC)), is(true));
assertThat(NetworkUtils.getLocalHost().getHostName(), equalTo(hit.field("node_host_name").getValue())); assertThat(DummyTransportAddress.INSTANCE.getHost(), equalTo(hit.field("node_host_name").getValue()));
assertThat(NetworkUtils.getLocalHost().getHostAddress(), equalTo(hit.field("node_host_address").getValue())); assertThat(DummyTransportAddress.INSTANCE.getAddress(), equalTo(hit.field("node_host_address").getValue()));
assertEquals(layer, hit.field("layer").getValue()); assertEquals(layer, hit.field("layer").getValue());
assertEquals(type, hit.field("event_type").getValue()); assertEquals(type, hit.field("event_type").getValue());
@ -602,13 +606,13 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
private static class RemoteHostMockMessage extends TransportMessage<RemoteHostMockMessage> { private static class RemoteHostMockMessage extends TransportMessage<RemoteHostMockMessage> {
RemoteHostMockMessage() throws Exception { RemoteHostMockMessage() throws Exception {
remoteAddress(new InetSocketTransportAddress(InetAddress.getLocalHost(), 1234)); remoteAddress(DummyTransportAddress.INSTANCE);
} }
} }
private static class RemoteHostMockTransportRequest extends TransportRequest { private static class RemoteHostMockTransportRequest extends TransportRequest {
RemoteHostMockTransportRequest() throws Exception { RemoteHostMockTransportRequest() throws Exception {
remoteAddress(new InetSocketTransportAddress(InetAddress.getLocalHost(), 1234)); remoteAddress(DummyTransportAddress.INSTANCE);
} }
} }
@ -710,7 +714,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
} }
static String remoteHostAddress() throws Exception { static String remoteHostAddress() throws Exception {
return InetAddress.getLocalHost().getHostAddress(); return DummyTransportAddress.INSTANCE.toString();
} }
} }

View File

@ -11,9 +11,7 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.network.NetworkUtils; import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.transport.*;
import org.elasticsearch.common.transport.LocalTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.User; import org.elasticsearch.shield.User;
import org.elasticsearch.shield.authc.AuthenticationToken; import org.elasticsearch.shield.authc.AuthenticationToken;
@ -21,6 +19,7 @@ import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.IPFilter; import org.elasticsearch.shield.transport.filter.IPFilter;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule; import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage; import org.elasticsearch.transport.TransportMessage;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -99,6 +98,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
private String prefix; private String prefix;
private Settings settings; private Settings settings;
private Transport transport;
@Before @Before
public void init() throws Exception { public void init() throws Exception {
@ -107,16 +107,18 @@ public class LoggingAuditTrailTests extends ESTestCase {
.put("shield.audit.logfile.prefix.emit_node_host_name", randomBoolean()) .put("shield.audit.logfile.prefix.emit_node_host_name", randomBoolean())
.put("shield.audit.logfile.prefix.emit_node_name", randomBoolean()) .put("shield.audit.logfile.prefix.emit_node_name", randomBoolean())
.build(); .build();
prefix = LoggingAuditTrail.resolvePrefix(settings); transport = mock(Transport.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(DummyTransportAddress.INSTANCE, DummyTransportAddress.INSTANCE));
prefix = LoggingAuditTrail.resolvePrefix(settings, transport);
} }
@Test @Test
public void testAnonymousAccessDenied_Transport() throws Exception { public void testAnonymousAccessDenied_Transport() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);
auditTrail.anonymousAccessDenied("_action", message); auditTrail.anonymousAccessDenied("_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -150,7 +152,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.anonymousAccessDenied(request); auditTrail.anonymousAccessDenied(request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -171,9 +173,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed() throws Exception { public void testAuthenticationFailed() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed(new MockToken(), "_action", message); auditTrail.authenticationFailed(new MockToken(), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -200,9 +202,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed_NoToken() throws Exception { public void testAuthenticationFailed_NoToken() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed("_action", message); auditTrail.authenticationFailed("_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -233,7 +235,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed(new MockToken(), request); auditTrail.authenticationFailed(new MockToken(), request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -256,7 +258,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed(request); auditTrail.authenticationFailed(request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -275,9 +277,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed_Realm() throws Exception { public void testAuthenticationFailed_Realm() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message); auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -304,7 +306,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed("_realm", new MockToken(), request); auditTrail.authenticationFailed("_realm", new MockToken(), request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -323,9 +325,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted() throws Exception { public void testAccessGranted() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(new User.Simple("_username", "r1"), "_action", message); auditTrail.accessGranted(new User.Simple("_username", "r1"), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -354,9 +356,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted_InternalSystemAction() throws Exception { public void testAccessGranted_InternalSystemAction() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(User.SYSTEM, "internal:_action", message); auditTrail.accessGranted(User.SYSTEM, "internal:_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -379,9 +381,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted_InternalSystemAction_NonSystemUser() throws Exception { public void testAccessGranted_InternalSystemAction_NonSystemUser() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(new User.Simple("_username"), "internal:_action", message); auditTrail.accessGranted(new User.Simple("_username"), "internal:_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -410,9 +412,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessDenied() throws Exception { public void testAccessDenied() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message); String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessDenied(new User.Simple("_username", "r1"), "_action", message); auditTrail.accessDenied(new User.Simple("_username", "r1"), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -439,8 +441,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testConnectionDenied() throws Exception { public void testConnectionDenied() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditTrail.connectionDenied(inetAddress, "default", rule); auditTrail.connectionDenied(inetAddress, "default", rule);
switch (level) { switch (level) {
@ -460,8 +462,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testConnectionGranted() throws Exception { public void testConnectionGranted() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
InetAddress inetAddress = InetAddress.getLocalHost(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditTrail.connectionGranted(inetAddress, "default", rule); auditTrail.connectionGranted(inetAddress, "default", rule);
switch (level) { switch (level) {
@ -482,7 +484,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
@Test @Test
public void testOriginAttributes() throws Exception { public void testOriginAttributes() throws Exception {
MockMessage message = new MockMessage(); MockMessage message = new MockMessage();
String text = LoggingAuditTrail.originAttributes(message); String text = LoggingAuditTrail.originAttributes(message, transport);;
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message); InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message);
if (restAddress != null) { if (restAddress != null) {
assertThat(text, equalTo("origin_type=[rest], origin_address=[" + restAddress + "]")); assertThat(text, equalTo("origin_type=[rest], origin_address=[" + restAddress + "]"));
@ -490,7 +492,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
} }
TransportAddress address = message.remoteAddress(); TransportAddress address = message.remoteAddress();
if (address == null) { if (address == null) {
assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + NetworkUtils.getLocalHost().getHostAddress() + "]")); assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + transport.boundAddress().publishAddress().getAddress() + "]"));
return; return;
} }

View File

@ -7,7 +7,6 @@ package org.elasticsearch.shield.transport.filter;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.transport.InetSocketTransportAddress;
@ -45,12 +44,12 @@ public class IPFilterTests extends ESTestCase {
nodeSettingsService = mock(NodeSettingsService.class); nodeSettingsService = mock(NodeSettingsService.class);
httpTransport = mock(HttpServerTransport.class); httpTransport = mock(HttpServerTransport.class);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9200); InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(InetAddress.getLoopbackAddress().getHostAddress(), 9200);
when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress)); when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress));
when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
transport = mock(Transport.class); transport = mock(Transport.class);
InetSocketTransportAddress address = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9300); InetSocketTransportAddress address = new InetSocketTransportAddress(InetAddress.getLoopbackAddress().getHostAddress(), 9300);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address)); when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address));
when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
} }

View File

@ -7,7 +7,6 @@ package org.elasticsearch.shield.transport.netty;
import com.google.common.net.InetAddresses; import com.google.common.net.InetAddresses;
import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress; import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress; import org.elasticsearch.common.transport.InetSocketTransportAddress;
@ -21,6 +20,7 @@ import org.jboss.netty.channel.*;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.net.InetAddress;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.SocketAddress; import java.net.SocketAddress;
@ -46,7 +46,7 @@ public class IPFilterNettyUpstreamHandlerTests extends ESTestCase {
boolean isHttpEnabled = randomBoolean(); boolean isHttpEnabled = randomBoolean();
Transport transport = mock(Transport.class); Transport transport = mock(Transport.class);
InetSocketTransportAddress address = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9300); InetSocketTransportAddress address = new InetSocketTransportAddress(InetAddress.getLoopbackAddress(), 9300);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address)); when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address));
when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
@ -55,7 +55,7 @@ public class IPFilterNettyUpstreamHandlerTests extends ESTestCase {
if (isHttpEnabled) { if (isHttpEnabled) {
HttpServerTransport httpTransport = mock(HttpServerTransport.class); HttpServerTransport httpTransport = mock(HttpServerTransport.class);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9200); InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(InetAddress.getLoopbackAddress(), 9200);
when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress)); when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress));
when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
ipFilter.setHttpServerTransport(httpTransport); ipFilter.setHttpServerTransport(httpTransport);

View File

@ -5,6 +5,7 @@
*/ */
package org.elasticsearch.shield.transport.netty; package org.elasticsearch.shield.transport.netty;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ShieldIntegTestCase; import org.elasticsearch.test.ShieldIntegTestCase;
@ -16,6 +17,7 @@ import java.nio.file.Path;
import static org.elasticsearch.common.settings.Settings.settingsBuilder; import static org.elasticsearch.common.settings.Settings.settingsBuilder;
import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.CoreMatchers.is;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/x-plugins/issues/468")
public class IPHostnameVerificationTests extends ShieldIntegTestCase { public class IPHostnameVerificationTests extends ShieldIntegTestCase {
Path keystore; Path keystore;