remove NetworkUtils and InetAddress getLocalHost usage in shield

Original commit: elastic/x-pack-elasticsearch@460ef63824
This commit is contained in:
jaymode 2015-08-18 13:30:56 -04:00
parent 4756f07f2b
commit 152aeaa776
9 changed files with 119 additions and 98 deletions

View File

@ -27,8 +27,6 @@ import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.inject.Provider;
import org.elasticsearch.common.io.Streams;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
@ -48,6 +46,7 @@ import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.authz.Privilege;
import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest;
import org.joda.time.DateTime;
@ -106,6 +105,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
private final Environment environment;
private final LinkedBlockingQueue<Message> eventQueue;
private final QueueConsumer queueConsumer;
private final Transport transport;
private final boolean indexToRemoteCluster;
private BulkProcessor bulkProcessor;
@ -123,12 +123,13 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
@Inject
public IndexAuditTrail(Settings settings, IndexAuditUserHolder indexingAuditUser,
Environment environment, AuthenticationService authenticationService,
Provider<Client> clientProvider) {
Transport transport, Provider<Client> clientProvider) {
super(settings);
this.auditUser = indexingAuditUser;
this.authenticationService = authenticationService;
this.clientProvider = clientProvider;
this.environment = environment;
this.transport = transport;
this.nodeName = settings.get("name");
this.queueConsumer = new QueueConsumer(EsExecutors.threadName(settings, "audit-queue-consumer"));
@ -239,8 +240,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
*/
public void start(boolean master) {
if (state.compareAndSet(State.INITIALIZED, State.STARTING)) {
this.nodeHostName = NetworkUtils.getLocalHost().getHostName();
this.nodeHostAddress = NetworkUtils.getLocalHost().getHostAddress();
this.nodeHostName = transport.boundAddress().publishAddress().getHost();
this.nodeHostAddress = transport.boundAddress().publishAddress().getAddress();
if (client == null) {
initializeClient();
@ -461,7 +462,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
Message msg = new Message().start();
common("transport", type, msg.builder);
originAttributes(message, msg.builder);
originAttributes(message, msg.builder, transport);
if (action != null) {
msg.builder.field(Field.ACTION, action);
@ -535,7 +536,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
return builder;
}
private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder) throws IOException {
private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder, Transport transport) throws IOException {
// first checking if the message originated in a rest call
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message);
@ -559,7 +560,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail {
// the call was originated locally on this node
builder.field(Field.ORIGIN_TYPE, "local_node");
builder.field(Field.ORIGIN_ADDRESS, NetworkUtils.getLocalHost().getHostAddress());
builder.field(Field.ORIGIN_ADDRESS, transport.boundAddress().publishAddress().getAddress());
return builder;
}

View File

@ -19,6 +19,7 @@ import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.authz.Privilege;
import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest;
@ -38,6 +39,7 @@ public class LoggingAuditTrail implements AuditTrail {
private final String prefix;
private final ESLogger logger;
private final Transport transport;
@Override
public String name() {
@ -45,17 +47,18 @@ public class LoggingAuditTrail implements AuditTrail {
}
@Inject
public LoggingAuditTrail(Settings settings) {
this(resolvePrefix(settings), Loggers.getLogger(LoggingAuditTrail.class));
public LoggingAuditTrail(Settings settings, Transport transport) {
this(resolvePrefix(settings, transport), transport, Loggers.getLogger(LoggingAuditTrail.class));
}
LoggingAuditTrail(Settings settings, ESLogger logger) {
this(resolvePrefix(settings), logger);
LoggingAuditTrail(Settings settings, Transport transport, ESLogger logger) {
this(resolvePrefix(settings, transport), transport, logger);
}
LoggingAuditTrail(String prefix, ESLogger logger) {
LoggingAuditTrail(String prefix, Transport transport, ESLogger logger) {
this.logger = logger;
this.prefix = prefix;
this.transport = transport;
}
@Override
@ -63,15 +66,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message);
if (indices != null) {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), action, indices, message.getClass().getSimpleName());
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), action, indices, message.getClass().getSimpleName());
} else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message), action, indices);
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, transport), action, indices);
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message), action, message.getClass().getSimpleName());
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, transport), action, message.getClass().getSimpleName());
} else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", prefix, originAttributes(message), action);
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", prefix, originAttributes(message, transport), action);
}
}
}
@ -90,15 +93,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message);
if (indices != null) {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), token.principal(), action, indices, message.getClass().getSimpleName());
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), token.principal(), action, indices, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), token.principal(), action, indices);
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), token.principal(), action, indices);
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), token.principal(), action, message.getClass().getSimpleName());
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), token.principal(), action, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), token.principal(), action);
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), token.principal(), action);
}
}
}
@ -117,15 +120,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message);
if (indices != null) {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), action, indices, message.getClass().getSimpleName());
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), action, indices, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message), action, indices);
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, transport), action, indices);
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message), action, message.getClass().getSimpleName());
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, transport), action, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", prefix, originAttributes(message), action);
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", prefix, originAttributes(message, transport), action);
}
}
}
@ -144,9 +147,9 @@ public class LoggingAuditTrail implements AuditTrail {
if (logger.isTraceEnabled()) {
String indices = indicesString(message);
if (indices != null) {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, realm, originAttributes(message), token.principal(), action, indices, message.getClass().getSimpleName());
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, realm, originAttributes(message, transport), token.principal(), action, indices, message.getClass().getSimpleName());
} else {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]", prefix, realm, originAttributes(message), token.principal(), action, message.getClass().getSimpleName());
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]", prefix, realm, originAttributes(message, transport), token.principal(), action, message.getClass().getSimpleName());
}
}
}
@ -166,9 +169,9 @@ public class LoggingAuditTrail implements AuditTrail {
if (user.isSystem() && Privilege.SYSTEM.predicate().apply(action)) {
if (logger.isTraceEnabled()) {
if (indices != null) {
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName());
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else {
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName());
logger.trace("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
}
}
return;
@ -176,15 +179,15 @@ public class LoggingAuditTrail implements AuditTrail {
if (indices != null) {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName());
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else {
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), user.principal(), action, indices);
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices);
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName());
logger.debug("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
} else {
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), user.principal(), action);
logger.info("{}[transport] [access_granted]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), user.principal(), action);
}
}
}
@ -194,15 +197,15 @@ public class LoggingAuditTrail implements AuditTrail {
String indices = indicesString(message);
if (indices != null) {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, indices, message.getClass().getSimpleName());
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message), user.principal(), action, indices);
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, originAttributes(message, transport), user.principal(), action, indices);
}
} else {
if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message), user.principal(), action, message.getClass().getSimpleName());
logger.debug("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, originAttributes(message, transport), user.principal(), action, message.getClass().getSimpleName());
} else {
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message), user.principal(), action);
logger.error("{}[transport] [access_denied]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, transport), user.principal(), action);
}
}
}
@ -241,7 +244,7 @@ public class LoggingAuditTrail implements AuditTrail {
return "origin_address=[" + request.getRemoteAddress() + "]";
}
static String originAttributes(TransportMessage message) {
static String originAttributes(TransportMessage message, Transport transport) {
StringBuilder builder = new StringBuilder();
// first checking if the message originated in a rest call
@ -265,21 +268,21 @@ public class LoggingAuditTrail implements AuditTrail {
// the call was originated locally on this node
return builder.append("origin_type=[local_node], origin_address=[")
.append(NetworkUtils.getLocalHost().getHostAddress())
.append(transport.boundAddress().publishAddress().getAddress())
.append("]")
.toString();
}
static String resolvePrefix(Settings settings) {
static String resolvePrefix(Settings settings, Transport transport) {
StringBuilder builder = new StringBuilder();
if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_address", false)) {
String address = NetworkUtils.getLocalHost().getHostAddress();
String address = transport.boundAddress().publishAddress().getAddress();
if (address != null) {
builder.append("[").append(address).append("] ");
}
}
if (settings.getAsBoolean("shield.audit.logfile.prefix.emit_node_host_name", false)) {
String hostName = NetworkUtils.getLocalHost().getHostAddress();
String hostName = transport.boundAddress().publishAddress().getHost();
if (hostName != null) {
builder.append("[").append(hostName).append("] ");
}

View File

@ -5,12 +5,17 @@
*/
package org.elasticsearch.shield.audit;
import org.elasticsearch.Version;
import org.elasticsearch.common.inject.Guice;
import org.elasticsearch.common.inject.Injector;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsModule;
import org.elasticsearch.indices.breaker.CircuitBreakerModule;
import org.elasticsearch.shield.audit.logfile.LoggingAuditTrail;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.threadpool.ThreadPoolModule;
import org.elasticsearch.transport.TransportModule;
import org.junit.Test;
import static org.hamcrest.Matchers.*;
@ -46,13 +51,18 @@ public class AuditTrailModuleTests extends ESTestCase {
.put("shield.audit.enabled", true)
.put("client.type", "node")
.build();
Injector injector = Guice.createInjector(new SettingsModule(settings), new AuditTrailModule(settings));
ThreadPool pool = new ThreadPool("testLogFile");
try {
Injector injector = Guice.createInjector(new SettingsModule(settings), new AuditTrailModule(settings), new TransportModule(settings), new CircuitBreakerModule(settings), new ThreadPoolModule(pool), new Version.Module(Version.CURRENT));
AuditTrail auditTrail = injector.getInstance(AuditTrail.class);
assertThat(auditTrail, instanceOf(AuditTrailService.class));
AuditTrailService service = (AuditTrailService) auditTrail;
assertThat(service.auditTrails, notNullValue());
assertThat(service.auditTrails.length, is(1));
assertThat(service.auditTrails[0], instanceOf(LoggingAuditTrail.class));
} finally {
pool.shutdown();
}
}
@Test

View File

@ -124,7 +124,7 @@ public class AuditTrailServiceTests extends ESTestCase {
@Test
public void testConnectionGranted() throws Exception {
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = randomBoolean() ? ShieldIpFilterRule.ACCEPT_ALL : IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
service.connectionGranted(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) {
@ -134,7 +134,7 @@ public class AuditTrailServiceTests extends ESTestCase {
@Test
public void testConnectionDenied() throws Exception {
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
service.connectionDenied(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) {

View File

@ -14,8 +14,9 @@ import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.inject.util.Providers;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.DummyTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.transport.LocalTransportAddress;
import org.elasticsearch.env.Environment;
@ -32,6 +33,7 @@ import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.test.InternalTestCluster;
import org.elasticsearch.test.ShieldIntegTestCase;
import org.elasticsearch.test.ShieldSettingsSource;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportInfo;
import org.elasticsearch.transport.TransportMessage;
import org.elasticsearch.transport.TransportRequest;
@ -175,9 +177,11 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
logger.info("--> settings: [{}]", settings.getAsMap().toString());
when(authService.authenticate(mock(RestRequest.class))).thenThrow(new UnsupportedOperationException(""));
when(authService.authenticate("_action", new LocalHostMockMessage(), user.user())).thenThrow(new UnsupportedOperationException(""));
Transport transport = mock(Transport.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(DummyTransportAddress.INSTANCE, DummyTransportAddress.INSTANCE));
Environment env = new Environment(settings);
auditor = new IndexAuditTrail(settings, user, env, authService, Providers.of(client()));
auditor = new IndexAuditTrail(settings, user, env, authService, transport, Providers.of(client()));
auditor.start(true);
}
@ -536,7 +540,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
public void testConnectionGranted() throws Exception {
initialize();
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditor.connectionGranted(inetAddress, "default", rule);
awaitIndexCreation(resolveIndexName());
@ -551,7 +555,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
@Test(expected = IndexNotFoundException.class)
public void testConnectionGranted_Muted() throws Exception {
initialize("connection_granted");
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditor.connectionGranted(inetAddress, "default", rule);
getClient().prepareExists(resolveIndexName()).execute().actionGet();
@ -561,7 +565,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
public void testConnectionDenied() throws Exception {
initialize();
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditor.connectionDenied(inetAddress, "default", rule);
awaitIndexCreation(resolveIndexName());
@ -576,7 +580,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
@Test(expected = IndexNotFoundException.class)
public void testConnectionDenied_Muted() throws Exception {
initialize("connection_denied");
InetAddress inetAddress = InetAddress.getLocalHost();
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditor.connectionDenied(inetAddress, "default", rule);
getClient().prepareExists(resolveIndexName()).execute().actionGet();
@ -587,8 +591,8 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
DateTime dateTime = ISODateTimeFormat.dateTimeParser().withZoneUTC().parseDateTime((String) hit.field("@timestamp").getValue());
assertThat(dateTime.isBefore(DateTime.now(DateTimeZone.UTC)), is(true));
assertThat(NetworkUtils.getLocalHost().getHostName(), equalTo(hit.field("node_host_name").getValue()));
assertThat(NetworkUtils.getLocalHost().getHostAddress(), equalTo(hit.field("node_host_address").getValue()));
assertThat(DummyTransportAddress.INSTANCE.getHost(), equalTo(hit.field("node_host_name").getValue()));
assertThat(DummyTransportAddress.INSTANCE.getAddress(), equalTo(hit.field("node_host_address").getValue()));
assertEquals(layer, hit.field("layer").getValue());
assertEquals(type, hit.field("event_type").getValue());
@ -602,13 +606,13 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
private static class RemoteHostMockMessage extends TransportMessage<RemoteHostMockMessage> {
RemoteHostMockMessage() throws Exception {
remoteAddress(new InetSocketTransportAddress(InetAddress.getLocalHost(), 1234));
remoteAddress(DummyTransportAddress.INSTANCE);
}
}
private static class RemoteHostMockTransportRequest extends TransportRequest {
RemoteHostMockTransportRequest() throws Exception {
remoteAddress(new InetSocketTransportAddress(InetAddress.getLocalHost(), 1234));
remoteAddress(DummyTransportAddress.INSTANCE);
}
}
@ -710,7 +714,7 @@ public class IndexAuditTrailTests extends ShieldIntegTestCase {
}
static String remoteHostAddress() throws Exception {
return InetAddress.getLocalHost().getHostAddress();
return DummyTransportAddress.INSTANCE.toString();
}
}

View File

@ -11,9 +11,7 @@ import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
import org.elasticsearch.common.transport.LocalTransportAddress;
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.transport.*;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.User;
import org.elasticsearch.shield.authc.AuthenticationToken;
@ -21,6 +19,7 @@ import org.elasticsearch.shield.rest.RemoteHostHeader;
import org.elasticsearch.shield.transport.filter.IPFilter;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.TransportMessage;
import org.junit.Before;
import org.junit.Test;
@ -99,6 +98,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
private String prefix;
private Settings settings;
private Transport transport;
@Before
public void init() throws Exception {
@ -107,16 +107,18 @@ public class LoggingAuditTrailTests extends ESTestCase {
.put("shield.audit.logfile.prefix.emit_node_host_name", randomBoolean())
.put("shield.audit.logfile.prefix.emit_node_name", randomBoolean())
.build();
prefix = LoggingAuditTrail.resolvePrefix(settings);
transport = mock(Transport.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(DummyTransportAddress.INSTANCE, DummyTransportAddress.INSTANCE));
prefix = LoggingAuditTrail.resolvePrefix(settings, transport);
}
@Test
public void testAnonymousAccessDenied_Transport() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);
auditTrail.anonymousAccessDenied("_action", message);
switch (level) {
case ERROR:
@ -150,7 +152,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.anonymousAccessDenied(request);
switch (level) {
case ERROR:
@ -171,9 +173,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed(new MockToken(), "_action", message);
switch (level) {
case ERROR:
@ -200,9 +202,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed_NoToken() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed("_action", message);
switch (level) {
case ERROR:
@ -233,7 +235,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed(new MockToken(), request);
switch (level) {
case ERROR:
@ -256,7 +258,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed(request);
switch (level) {
case ERROR:
@ -275,9 +277,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAuthenticationFailed_Realm() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message);
switch (level) {
case ERROR:
@ -304,7 +306,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
auditTrail.authenticationFailed("_realm", new MockToken(), request);
switch (level) {
case ERROR:
@ -323,9 +325,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(new User.Simple("_username", "r1"), "_action", message);
switch (level) {
case ERROR:
@ -354,9 +356,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted_InternalSystemAction() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(User.SYSTEM, "internal:_action", message);
switch (level) {
case ERROR:
@ -379,9 +381,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessGranted_InternalSystemAction_NonSystemUser() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessGranted(new User.Simple("_username"), "internal:_action", message);
switch (level) {
case ERROR:
@ -410,9 +412,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testAccessDenied() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest();
String origins = LoggingAuditTrail.originAttributes(message);
String origins = LoggingAuditTrail.originAttributes(message, transport);;
auditTrail.accessDenied(new User.Simple("_username", "r1"), "_action", message);
switch (level) {
case ERROR:
@ -439,8 +441,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testConnectionDenied() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
InetAddress inetAddress = InetAddress.getLocalHost();
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
auditTrail.connectionDenied(inetAddress, "default", rule);
switch (level) {
@ -460,8 +462,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testConnectionGranted() throws Exception {
for (Level level : Level.values()) {
CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, logger);
InetAddress inetAddress = InetAddress.getLocalHost();
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger);
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditTrail.connectionGranted(inetAddress, "default", rule);
switch (level) {
@ -482,7 +484,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
@Test
public void testOriginAttributes() throws Exception {
MockMessage message = new MockMessage();
String text = LoggingAuditTrail.originAttributes(message);
String text = LoggingAuditTrail.originAttributes(message, transport);;
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message);
if (restAddress != null) {
assertThat(text, equalTo("origin_type=[rest], origin_address=[" + restAddress + "]"));
@ -490,7 +492,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
}
TransportAddress address = message.remoteAddress();
if (address == null) {
assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + NetworkUtils.getLocalHost().getHostAddress() + "]"));
assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + transport.boundAddress().publishAddress().getAddress() + "]"));
return;
}

View File

@ -7,7 +7,6 @@ package org.elasticsearch.shield.transport.filter;
import com.google.common.net.InetAddresses;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
@ -45,12 +44,12 @@ public class IPFilterTests extends ESTestCase {
nodeSettingsService = mock(NodeSettingsService.class);
httpTransport = mock(HttpServerTransport.class);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9200);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(InetAddress.getLoopbackAddress().getHostAddress(), 9200);
when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress));
when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
transport = mock(Transport.class);
InetSocketTransportAddress address = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9300);
InetSocketTransportAddress address = new InetSocketTransportAddress(InetAddress.getLoopbackAddress().getHostAddress(), 9300);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address));
when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
}

View File

@ -7,7 +7,6 @@ package org.elasticsearch.shield.transport.netty;
import com.google.common.net.InetAddresses;
import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.network.NetworkUtils;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.transport.BoundTransportAddress;
import org.elasticsearch.common.transport.InetSocketTransportAddress;
@ -21,6 +20,7 @@ import org.jboss.netty.channel.*;
import org.junit.Before;
import org.junit.Test;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
@ -46,7 +46,7 @@ public class IPFilterNettyUpstreamHandlerTests extends ESTestCase {
boolean isHttpEnabled = randomBoolean();
Transport transport = mock(Transport.class);
InetSocketTransportAddress address = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9300);
InetSocketTransportAddress address = new InetSocketTransportAddress(InetAddress.getLoopbackAddress(), 9300);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(address, address));
when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
@ -55,7 +55,7 @@ public class IPFilterNettyUpstreamHandlerTests extends ESTestCase {
if (isHttpEnabled) {
HttpServerTransport httpTransport = mock(HttpServerTransport.class);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(NetworkUtils.getLocalHost().getHostAddress(), 9200);
InetSocketTransportAddress httpAddress = new InetSocketTransportAddress(InetAddress.getLoopbackAddress(), 9200);
when(httpTransport.boundAddress()).thenReturn(new BoundTransportAddress(httpAddress, httpAddress));
when(httpTransport.lifecycleState()).thenReturn(Lifecycle.State.STARTED);
ipFilter.setHttpServerTransport(httpTransport);

View File

@ -5,6 +5,7 @@
*/
package org.elasticsearch.shield.transport.netty;
import org.apache.lucene.util.LuceneTestCase;
import org.elasticsearch.client.Client;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.test.ShieldIntegTestCase;
@ -16,6 +17,7 @@ import java.nio.file.Path;
import static org.elasticsearch.common.settings.Settings.settingsBuilder;
import static org.hamcrest.CoreMatchers.is;
@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/x-plugins/issues/468")
public class IPHostnameVerificationTests extends ShieldIntegTestCase {
Path keystore;