Internal: Remove use of Transport in audit trails

Both logfile and index audit trails currently depend on injection of
Transport in order to find the bound address of the local node. However,
the ClusterService provides access to information about the local node,
including the bound addresses. This change makes the audit trails use
the cluster service, and also makes the logging audit trail not use a
lifecycle.

Original commit: elastic/x-pack-elasticsearch@d747d64ee1
This commit is contained in:
Ryan Ernst 2016-07-13 20:43:55 -07:00
parent 4224d70986
commit f481dea1d0
9 changed files with 222 additions and 297 deletions

View File

@ -18,6 +18,7 @@ import java.util.function.Function;
import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.support.ActionFilter; import org.elasticsearch.action.support.ActionFilter;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Booleans; import org.elasticsearch.common.Booleans;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.component.LifecycleComponent; import org.elasticsearch.common.component.LifecycleComponent;
@ -179,11 +180,6 @@ public class Security implements ActionPlugin {
return Collections.emptyList(); return Collections.emptyList();
} }
List<Class<? extends LifecycleComponent>> list = new ArrayList<>(); List<Class<? extends LifecycleComponent>> list = new ArrayList<>();
//TODO why only focus on file audit logs? shouldn't we just check if audit trail is enabled in general?
if (AuditTrailModule.fileAuditLoggingEnabled(settings) == true) {
list.add(LoggingAuditTrail.class);
}
list.add(SecurityLicensee.class); list.add(SecurityLicensee.class);
list.add(FileRolesStore.class); list.add(FileRolesStore.class);
list.add(Realms.class); list.add(Realms.class);

View File

@ -55,8 +55,7 @@ public class AuditTrailModule extends AbstractSecurityModule.Node {
bind(AuditTrailService.class).asEagerSingleton(); bind(AuditTrailService.class).asEagerSingleton();
bind(AuditTrail.class).to(AuditTrailService.class); bind(AuditTrail.class).to(AuditTrailService.class);
Multibinder<AuditTrail> binder = Multibinder.newSetBinder(binder(), AuditTrail.class); Multibinder<AuditTrail> binder = Multibinder.newSetBinder(binder(), AuditTrail.class);
Set<String> uniqueOutputs = Sets.newHashSet(outputs); for (String output : outputs) {
for (String output : uniqueOutputs) {
switch (output) { switch (output) {
case LoggingAuditTrail.NAME: case LoggingAuditTrail.NAME:
binder.addBinding().to(LoggingAuditTrail.class); binder.addBinding().to(LoggingAuditTrail.class);
@ -67,7 +66,7 @@ public class AuditTrailModule extends AbstractSecurityModule.Node {
bind(IndexAuditTrail.class).asEagerSingleton(); bind(IndexAuditTrail.class).asEagerSingleton();
break; break;
default: default:
throw new ElasticsearchException("unknown audit trail output [" + output + "]"); throw new IllegalArgumentException("unknown audit trail output [" + output + "]");
} }
} }
} }

View File

@ -19,6 +19,7 @@ import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterChangedEvent;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.ClusterState;
import org.elasticsearch.cluster.ClusterStateListener; import org.elasticsearch.cluster.ClusterStateListener;
@ -153,7 +154,6 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
private final Provider<InternalClient> clientProvider; private final Provider<InternalClient> clientProvider;
private final BlockingQueue<Message> eventQueue; private final BlockingQueue<Message> eventQueue;
private final QueueConsumer queueConsumer; private final QueueConsumer queueConsumer;
private final Transport transport;
private final ThreadPool threadPool; private final ThreadPool threadPool;
private final Lock putMappingLock = new ReentrantLock(); private final Lock putMappingLock = new ReentrantLock();
private final ClusterService clusterService; private final ClusterService clusterService;
@ -172,11 +172,10 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
} }
@Inject @Inject
public IndexAuditTrail(Settings settings, Transport transport, public IndexAuditTrail(Settings settings, Provider<InternalClient> clientProvider, ThreadPool threadPool,
Provider<InternalClient> clientProvider, ThreadPool threadPool, ClusterService clusterService) { ClusterService clusterService) {
super(settings); super(settings);
this.clientProvider = clientProvider; this.clientProvider = clientProvider;
this.transport = transport;
this.threadPool = threadPool; this.threadPool = threadPool;
this.clusterService = clusterService; this.clusterService = clusterService;
this.nodeName = settings.get("name"); this.nodeName = settings.get("name");
@ -277,8 +276,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
*/ */
public void start(boolean master) { public void start(boolean master) {
if (state.compareAndSet(State.INITIALIZED, State.STARTING)) { if (state.compareAndSet(State.INITIALIZED, State.STARTING)) {
this.nodeHostName = transport.boundAddress().publishAddress().getHost(); this.nodeHostName = clusterService.localNode().getHostName();
this.nodeHostAddress = transport.boundAddress().publishAddress().getAddress(); this.nodeHostAddress = clusterService.localNode().getHostAddress();
if (client == null) { if (client == null) {
initializeClient(); initializeClient();
@ -545,7 +544,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
Message msg = new Message().start(); Message msg = new Message().start();
common("transport", type, msg.builder); common("transport", type, msg.builder);
originAttributes(message, msg.builder, transport, threadPool.getThreadContext()); originAttributes(message, msg.builder, clusterService.localNode(), threadPool.getThreadContext());
if (action != null) { if (action != null) {
msg.builder.field(Field.ACTION, action); msg.builder.field(Field.ACTION, action);
@ -577,7 +576,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
Message msg = new Message().start(); Message msg = new Message().start();
common("transport", type, msg.builder); common("transport", type, msg.builder);
originAttributes(message, msg.builder, transport, threadPool.getThreadContext()); originAttributes(message, msg.builder, clusterService.localNode(), threadPool.getThreadContext());
if (action != null) { if (action != null) {
msg.builder.field(Field.ACTION, action); msg.builder.field(Field.ACTION, action);
@ -672,8 +671,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
return builder; return builder;
} }
private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder, Transport transport, ThreadContext private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder,
threadContext) throws IOException { DiscoveryNode localNode, ThreadContext threadContext) throws IOException {
// first checking if the message originated in a rest call // first checking if the message originated in a rest call
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(threadContext); InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(threadContext);
@ -698,7 +697,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail, Cl
// the call was originated locally on this node // the call was originated locally on this node
builder.field(Field.ORIGIN_TYPE, "local_node"); builder.field(Field.ORIGIN_TYPE, "local_node");
builder.field(Field.ORIGIN_ADDRESS, transport.boundAddress().publishAddress().getAddress()); builder.field(Field.ORIGIN_ADDRESS, localNode.getHostAddress());
return builder; return builder;
} }

View File

@ -5,6 +5,9 @@
*/ */
package org.elasticsearch.xpack.security.audit.logfile; package org.elasticsearch.xpack.security.audit.logfile;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.component.AbstractLifecycleComponent;
import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.Lifecycle;
import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.component.LifecycleListener;
@ -44,7 +47,7 @@ import static org.elasticsearch.xpack.security.Security.setting;
/** /**
* *
*/ */
public class LoggingAuditTrail extends AbstractLifecycleComponent implements AuditTrail { public class LoggingAuditTrail extends AbstractComponent implements AuditTrail {
public static final String NAME = "logfile"; public static final String NAME = "logfile";
public static final Setting<Boolean> HOST_ADDRESS_SETTING = public static final Setting<Boolean> HOST_ADDRESS_SETTING =
@ -55,7 +58,7 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
Setting.boolSetting(setting("audit.logfile.prefix.emit_node_name"), true, Property.NodeScope); Setting.boolSetting(setting("audit.logfile.prefix.emit_node_name"), true, Property.NodeScope);
private final ESLogger logger; private final ESLogger logger;
private final Transport transport; private final ClusterService clusterService;
private final ThreadContext threadContext; private final ThreadContext threadContext;
private String prefix; private String prefix;
@ -66,43 +69,22 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
} }
@Inject @Inject
public LoggingAuditTrail(Settings settings, Transport transport, ThreadPool threadPool) { public LoggingAuditTrail(Settings settings, ClusterService clusterService, ThreadPool threadPool) {
this(settings, transport, Loggers.getLogger(LoggingAuditTrail.class), threadPool.getThreadContext()); this(settings, clusterService, Loggers.getLogger(LoggingAuditTrail.class), threadPool.getThreadContext());
} }
LoggingAuditTrail(Settings settings, Transport transport, ESLogger logger, ThreadContext threadContext) { LoggingAuditTrail(Settings settings, ClusterService clusterService, ESLogger logger, ThreadContext threadContext) {
this("", settings, transport, logger, threadContext);
}
LoggingAuditTrail(String prefix, Settings settings, Transport transport, ESLogger logger, ThreadContext threadContext) {
super(settings); super(settings);
this.logger = logger; this.logger = logger;
this.prefix = prefix; this.clusterService = clusterService;
this.transport = transport;
this.threadContext = threadContext; this.threadContext = threadContext;
} }
private String getPrefix() {
@Override if (prefix == null) {
protected void doStart() { prefix = resolvePrefix(settings, clusterService.localNode());
if (transport.lifecycleState() == Lifecycle.State.STARTED) {
prefix = resolvePrefix(settings, transport);
} else {
transport.addLifecycleListener(new LifecycleListener() {
@Override
public void afterStart() {
prefix = resolvePrefix(settings, transport);
} }
}); return prefix;
}
}
@Override
protected void doStop() {
}
@Override
protected void doClose() {
} }
@Override @Override
@ -110,19 +92,20 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action, indices, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action, indices,
message.getClass().getSimpleName());
} else { } else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}], indices=[{}]", getPrefix(),
transport, threadContext), action, indices); originAttributes(message, clusterService.localNode(), threadContext), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, logger.debug("{}[transport] [anonymous_access_denied]\t{}, action=[{}], request=[{}]", getPrefix(),
transport, threadContext), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action, message.getClass().getSimpleName());
} else { } else {
logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", prefix, originAttributes(message, transport, logger.warn("{}[transport] [anonymous_access_denied]\t{}, action=[{}]", getPrefix(),
threadContext), action); originAttributes(message, clusterService.localNode(), threadContext), action);
} }
} }
} }
@ -130,10 +113,10 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void anonymousAccessDenied(RestRequest request) { public void anonymousAccessDenied(RestRequest request) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[rest] [anonymous_access_denied]\t{}, uri=[{}], request_body=[{}]", prefix, hostAttributes(request), request logger.debug("{}[rest] [anonymous_access_denied]\t{}, uri=[{}], request_body=[{}]", getPrefix(),
.uri(), restRequestContent(request)); hostAttributes(request), request.uri(), restRequestContent(request));
} else { } else {
logger.warn("{}[rest] [anonymous_access_denied]\t{}, uri=[{}]", prefix, hostAttributes(request), request.uri()); logger.warn("{}[rest] [anonymous_access_denied]\t{}, uri=[{}]", getPrefix(), hostAttributes(request), request.uri());
} }
} }
@ -143,19 +126,20 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]", logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}], request=[{}]",
prefix, originAttributes(message, transport, threadContext), token.principal(), action, indices, message.getClass getPrefix(), originAttributes(message, clusterService.localNode(), threadContext), token.principal(),
().getSimpleName()); action, indices, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", prefix, logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], indices=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), token.principal(), action, indices); originAttributes(message, clusterService.localNode(), threadContext), token.principal(), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), token.principal(), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), token.principal(), action,
message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", prefix, originAttributes(message, logger.error("{}[transport] [authentication_failed]\t{}, principal=[{}], action=[{}]", getPrefix(),
transport, threadContext), token.principal(), action); originAttributes(message, clusterService.localNode(), threadContext), token.principal(), action);
} }
} }
} }
@ -163,10 +147,10 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void authenticationFailed(RestRequest request) { public void authenticationFailed(RestRequest request) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[rest] [authentication_failed]\t{}, uri=[{}], request_body=[{}]", prefix, hostAttributes(request), request logger.debug("{}[rest] [authentication_failed]\t{}, uri=[{}], request_body=[{}]", getPrefix(), hostAttributes(request),
.uri(), restRequestContent(request)); request.uri(), restRequestContent(request));
} else { } else {
logger.error("{}[rest] [authentication_failed]\t{}, uri=[{}]", prefix, hostAttributes(request), request.uri()); logger.error("{}[rest] [authentication_failed]\t{}, uri=[{}]", getPrefix(), hostAttributes(request), request.uri());
} }
} }
@ -175,19 +159,20 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action, indices, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action, indices,
message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", prefix, originAttributes(message, logger.error("{}[transport] [authentication_failed]\t{}, action=[{}], indices=[{}]", getPrefix(),
transport, threadContext), action, indices); originAttributes(message, clusterService.localNode(), threadContext), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", prefix, originAttributes(message, logger.debug("{}[transport] [authentication_failed]\t{}, action=[{}], request=[{}]", getPrefix(),
transport, threadContext), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action, message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", prefix, originAttributes(message, transport, logger.error("{}[transport] [authentication_failed]\t{}, action=[{}]", getPrefix(),
threadContext), action); originAttributes(message, clusterService.localNode(), threadContext), action);
} }
} }
} }
@ -195,11 +180,11 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void authenticationFailed(AuthenticationToken token, RestRequest request) { public void authenticationFailed(AuthenticationToken token, RestRequest request) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[rest] [authentication_failed]\t{}, principal=[{}], uri=[{}], request_body=[{}]", prefix, hostAttributes logger.debug("{}[rest] [authentication_failed]\t{}, principal=[{}], uri=[{}], request_body=[{}]", getPrefix(),
(request), token.principal(), request.uri(), restRequestContent(request)); hostAttributes(request), token.principal(), request.uri(), restRequestContent(request));
} else { } else {
logger.error("{}[rest] [authentication_failed]\t{}, principal=[{}], uri=[{}]", prefix, hostAttributes(request), token logger.error("{}[rest] [authentication_failed]\t{}, principal=[{}], uri=[{}]", getPrefix(), hostAttributes(request),
.principal(), request.uri()); token.principal(), request.uri());
} }
} }
@ -209,12 +194,12 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], " + logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], indices=[{}], " +
"request=[{}]", prefix, realm, originAttributes(message, transport, threadContext), token.principal(), action, "request=[{}]", getPrefix(), realm, originAttributes(message, clusterService.localNode(), threadContext),
indices, message.getClass().getSimpleName()); token.principal(), action, indices, message.getClass().getSimpleName());
} else { } else {
logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]", prefix, logger.trace("{}[transport] [authentication_failed]\trealm=[{}], {}, principal=[{}], action=[{}], request=[{}]",
realm, originAttributes(message, transport, threadContext), token.principal(), action, message.getClass() getPrefix(), realm, originAttributes(message, clusterService.localNode(), threadContext), token.principal(),
.getSimpleName()); action, message.getClass().getSimpleName());
} }
} }
} }
@ -222,8 +207,8 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void authenticationFailed(String realm, AuthenticationToken token, RestRequest request) { public void authenticationFailed(String realm, AuthenticationToken token, RestRequest request) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("{}[rest] [authentication_failed]\trealm=[{}], {}, principal=[{}], uri=[{}], request_body=[{}]", prefix, realm, logger.trace("{}[rest] [authentication_failed]\trealm=[{}], {}, principal=[{}], uri=[{}], request_body=[{}]", getPrefix(),
hostAttributes(request), token.principal(), request.uri(), restRequestContent(request)); realm, hostAttributes(request), token.principal(), request.uri(), restRequestContent(request));
} }
} }
@ -235,12 +220,12 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
if ((SystemUser.is(user) && SystemPrivilege.INSTANCE.predicate().test(action)) || XPackUser.is(user)) { if ((SystemUser.is(user) && SystemPrivilege.INSTANCE.predicate().test(action)) || XPackUser.is(user)) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
if (indices != null) { if (indices != null) {
logger.trace("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}], request=[{}]", prefix, logger.trace("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, indices, originAttributes(message, clusterService.localNode(), threadContext), principal(user), action, indices,
message.getClass().getSimpleName()); message.getClass().getSimpleName());
} else { } else {
logger.trace("{}[transport] [access_granted]\t{}, {}, action=[{}], request=[{}]", prefix, logger.trace("{}[transport] [access_granted]\t{}, {}, action=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, originAttributes(message, clusterService.localNode(), threadContext), principal(user), action,
message.getClass().getSimpleName()); message.getClass().getSimpleName());
} }
} }
@ -249,20 +234,21 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, indices, originAttributes(message, clusterService.localNode(), threadContext), principal(user), action, indices,
message.getClass().getSimpleName()); message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}]", prefix, logger.info("{}[transport] [access_granted]\t{}, {}, action=[{}], indices=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, indices); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_granted]\t{}, {}, action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [access_granted]\t{}, {}, action=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action,
message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [access_granted]\t{}, {}, action=[{}]", prefix, logger.info("{}[transport] [access_granted]\t{}, {}, action=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action);
} }
} }
} }
@ -272,20 +258,21 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, {}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [access_denied]\t{}, {}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, indices, originAttributes(message, clusterService.localNode(), threadContext), principal(user), action, indices,
message.getClass().getSimpleName()); message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [access_denied]\t{}, {}, action=[{}], indices=[{}]", prefix, logger.error("{}[transport] [access_denied]\t{}, {}, action=[{}], indices=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, indices); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [access_denied]\t{}, {}, action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [access_denied]\t{}, {}, action=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action,
message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [access_denied]\t{}, {}, action=[{}]", prefix, logger.error("{}[transport] [access_denied]\t{}, {}, action=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), principal(user), action); originAttributes(message, clusterService.localNode(), threadContext), principal(user), action);
} }
} }
} }
@ -293,10 +280,10 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void tamperedRequest(RestRequest request) { public void tamperedRequest(RestRequest request) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[rest] [tampered_request]\t{}, uri=[{}], request_body=[{}]", prefix, hostAttributes(request), request.uri(), logger.debug("{}[rest] [tampered_request]\t{}, uri=[{}], request_body=[{}]", getPrefix(), hostAttributes(request),
restRequestContent(request)); request.uri(), restRequestContent(request));
} else { } else {
logger.error("{}[rest] [tampered_request]\t{}, uri=[{}]", prefix, hostAttributes(request), request.uri()); logger.error("{}[rest] [tampered_request]\t{}, uri=[{}]", getPrefix(), hostAttributes(request), request.uri());
} }
} }
@ -305,19 +292,21 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(message); String indices = indicesString(message);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [tampered_request]\t{}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [tampered_request]\t{}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action, indices, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action, indices,
message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [tampered_request]\t{}, action=[{}], indices=[{}]", prefix, logger.error("{}[transport] [tampered_request]\t{}, action=[{}], indices=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action, indices); originAttributes(message, clusterService.localNode(), threadContext), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [tampered_request]\t{}, action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [tampered_request]\t{}, action=[{}], request=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action, message.getClass().getSimpleName()); originAttributes(message, clusterService.localNode(), threadContext), action,
message.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [tampered_request]\t{}, action=[{}]", prefix, logger.error("{}[transport] [tampered_request]\t{}, action=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), action); originAttributes(message, clusterService.localNode(), threadContext), action);
} }
} }
} }
@ -327,20 +316,21 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
String indices = indicesString(request); String indices = indicesString(request);
if (indices != null) { if (indices != null) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [tampered_request]\t{}, {}, action=[{}], indices=[{}], request=[{}]", prefix, logger.debug("{}[transport] [tampered_request]\t{}, {}, action=[{}], indices=[{}], request=[{}]", getPrefix(),
originAttributes(request, transport, threadContext), principal(user), action, indices, originAttributes(request, clusterService.localNode(), threadContext), principal(user), action, indices,
request.getClass().getSimpleName()); request.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [tampered_request]\t{}, {}, action=[{}], indices=[{}]", prefix, logger.error("{}[transport] [tampered_request]\t{}, {}, action=[{}], indices=[{}]", getPrefix(),
originAttributes(request, transport, threadContext), principal(user), action, indices); originAttributes(request, clusterService.localNode(), threadContext), principal(user), action, indices);
} }
} else { } else {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [tampered_request]\t{}, {}, action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [tampered_request]\t{}, {}, action=[{}], request=[{}]", getPrefix(),
originAttributes(request, transport, threadContext), principal(user), action, request.getClass().getSimpleName()); originAttributes(request, clusterService.localNode(), threadContext), principal(user), action,
request.getClass().getSimpleName());
} else { } else {
logger.error("{}[transport] [tampered_request]\t{}, {}, action=[{}]", prefix, logger.error("{}[transport] [tampered_request]\t{}, {}, action=[{}]", getPrefix(),
originAttributes(request, transport, threadContext), principal(user), action); originAttributes(request, clusterService.localNode(), threadContext), principal(user), action);
} }
} }
} }
@ -348,48 +338,50 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
@Override @Override
public void connectionGranted(InetAddress inetAddress, String profile, SecurityIpFilterRule rule) { public void connectionGranted(InetAddress inetAddress, String profile, SecurityIpFilterRule rule) {
if (logger.isTraceEnabled()) { if (logger.isTraceEnabled()) {
logger.trace("{}[ip_filter] [connection_granted]\torigin_address=[{}], transport_profile=[{}], rule=[{}]", prefix, logger.trace("{}[ip_filter] [connection_granted]\torigin_address=[{}], transport_profile=[{}], rule=[{}]", getPrefix(),
NetworkAddress.format(inetAddress), profile, rule); NetworkAddress.format(inetAddress), profile, rule);
} }
} }
@Override @Override
public void connectionDenied(InetAddress inetAddress, String profile, SecurityIpFilterRule rule) { public void connectionDenied(InetAddress inetAddress, String profile, SecurityIpFilterRule rule) {
logger.error("{}[ip_filter] [connection_denied]\torigin_address=[{}], transport_profile=[{}], rule=[{}]", prefix, logger.error("{}[ip_filter] [connection_denied]\torigin_address=[{}], transport_profile=[{}], rule=[{}]", getPrefix(),
NetworkAddress.format(inetAddress), profile, rule); NetworkAddress.format(inetAddress), profile, rule);
} }
@Override @Override
public void runAsGranted(User user, String action, TransportMessage message) { public void runAsGranted(User user, String action, TransportMessage message) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [run_as_granted]\t{}, principal=[{}], run_as_principal=[{}], action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [run_as_granted]\t{}, principal=[{}], run_as_principal=[{}], action=[{}], request=[{}]",
originAttributes(message, transport, threadContext), user.principal(), user.runAs().principal(), action, getPrefix(), originAttributes(message, clusterService.localNode(), threadContext), user.principal(),
message.getClass().getSimpleName()); user.runAs().principal(), action, message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [run_as_granted]\t{}, principal=[{}], run_as_principal=[{}], action=[{}]", prefix, logger.info("{}[transport] [run_as_granted]\t{}, principal=[{}], run_as_principal=[{}], action=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), user.principal(), user.runAs().principal(), action); originAttributes(message, clusterService.localNode(), threadContext), user.principal(),
user.runAs().principal(), action);
} }
} }
@Override @Override
public void runAsDenied(User user, String action, TransportMessage message) { public void runAsDenied(User user, String action, TransportMessage message) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[transport] [run_as_denied]\t{}, principal=[{}], run_as_principal=[{}], action=[{}], request=[{}]", prefix, logger.debug("{}[transport] [run_as_denied]\t{}, principal=[{}], run_as_principal=[{}], action=[{}], request=[{}]",
originAttributes(message, transport, threadContext), user.principal(), user.runAs().principal(), action, getPrefix(), originAttributes(message, clusterService.localNode(), threadContext), user.principal(),
message.getClass().getSimpleName()); user.runAs().principal(), action, message.getClass().getSimpleName());
} else { } else {
logger.info("{}[transport] [run_as_denied]\t{}, principal=[{}], run_as_principal=[{}], action=[{}]", prefix, logger.info("{}[transport] [run_as_denied]\t{}, principal=[{}], run_as_principal=[{}], action=[{}]", getPrefix(),
originAttributes(message, transport, threadContext), user.principal(), user.runAs().principal(), action); originAttributes(message, clusterService.localNode(), threadContext), user.principal(),
user.runAs().principal(), action);
} }
} }
@Override @Override
public void runAsDenied(User user, RestRequest request) { public void runAsDenied(User user, RestRequest request) {
if (logger.isDebugEnabled()) { if (logger.isDebugEnabled()) {
logger.debug("{}[rest] [run_as_denied]\t{}, principal=[{}], uri=[{}], request_body=[{}]", prefix, logger.debug("{}[rest] [run_as_denied]\t{}, principal=[{}], uri=[{}], request_body=[{}]", getPrefix(),
hostAttributes(request), user.principal(), request.uri(), restRequestContent(request)); hostAttributes(request), user.principal(), request.uri(), restRequestContent(request));
} else { } else {
logger.info("{}[transport] [run_as_denied]\t{}, principal=[{}], uri=[{}]", prefix, logger.info("{}[transport] [run_as_denied]\t{}, principal=[{}], uri=[{}]", getPrefix(),
hostAttributes(request), user.principal(), request.uri()); hostAttributes(request), user.principal(), request.uri());
} }
} }
@ -405,7 +397,7 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
return "origin_address=[" + formattedAddress + "]"; return "origin_address=[" + formattedAddress + "]";
} }
static String originAttributes(TransportMessage message, Transport transport, ThreadContext threadContext) { static String originAttributes(TransportMessage message, DiscoveryNode localNode, ThreadContext threadContext) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
// first checking if the message originated in a rest call // first checking if the message originated in a rest call
@ -433,21 +425,21 @@ public class LoggingAuditTrail extends AbstractLifecycleComponent implements Aud
// the call was originated locally on this node // the call was originated locally on this node
return builder.append("origin_type=[local_node], origin_address=[") return builder.append("origin_type=[local_node], origin_address=[")
.append(transport.boundAddress().publishAddress().getAddress()) .append(localNode.getHostAddress())
.append("]") .append("]")
.toString(); .toString();
} }
static String resolvePrefix(Settings settings, Transport transport) { static String resolvePrefix(Settings settings, DiscoveryNode localNode) {
StringBuilder builder = new StringBuilder(); StringBuilder builder = new StringBuilder();
if (HOST_ADDRESS_SETTING.get(settings)) { if (HOST_ADDRESS_SETTING.get(settings)) {
String address = transport.boundAddress().publishAddress().getAddress(); String address = localNode.getHostAddress();
if (address != null) { if (address != null) {
builder.append("[").append(address).append("] "); builder.append("[").append(address).append("] ");
} }
} }
if (HOST_NAME_SETTING.get(settings)) { if (HOST_NAME_SETTING.get(settings)) {
String hostName = transport.boundAddress().publishAddress().getHost(); String hostName = localNode.getHostName();
if (hostName != null) { if (hostName != null) {
builder.append("[").append(hostName).append("] "); builder.append("[").append(hostName).append("] ");
} }

View File

@ -5,92 +5,46 @@
*/ */
package org.elasticsearch.xpack.security.audit; package org.elasticsearch.xpack.security.audit;
import org.elasticsearch.common.inject.Guice; import org.elasticsearch.common.inject.ModuleTestCase;
import org.elasticsearch.common.inject.Injector;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.network.NetworkService;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsModule; import org.elasticsearch.xpack.security.audit.index.IndexAuditTrail;
import org.elasticsearch.indices.breaker.CircuitBreakerService;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.local.LocalTransport;
import org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrail; import org.elasticsearch.xpack.security.audit.logfile.LoggingAuditTrail;
import static org.hamcrest.Matchers.instanceOf; public class AuditTrailModuleTests extends ModuleTestCase {
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.notNullValue;
public class AuditTrailModuleTests extends ESTestCase {
public void testEnabled() throws Exception { public void testEnabled() throws Exception {
Settings settings = Settings.builder() Settings settings = Settings.builder().put(AuditTrailModule.ENABLED_SETTING.getKey(), true).build();
.put("client.type", "node") AuditTrailModule module = new AuditTrailModule(settings);
.put(AuditTrailModule.ENABLED_SETTING.getKey(), false) assertBinding(module, AuditTrail.class, AuditTrailService.class);
.build(); assertSetMultiBinding(module, AuditTrail.class, LoggingAuditTrail.class);
SettingsModule settingsModule = new SettingsModule(settings, AuditTrailModule.ENABLED_SETTING);
Injector injector = Guice.createInjector(settingsModule, new AuditTrailModule(settings));
AuditTrail auditTrail = injector.getInstance(AuditTrail.class);
assertThat(auditTrail, is(AuditTrail.NOOP));
} }
public void testDisabledByDefault() throws Exception { public void testDisabledByDefault() throws Exception {
Settings settings = Settings.builder() AuditTrailModule module = new AuditTrailModule(Settings.EMPTY);
.put("client.type", "node").build(); assertInstanceBinding(module, AuditTrail.class, x -> x == AuditTrail.NOOP);
Injector injector = Guice.createInjector(new SettingsModule(settings), new AuditTrailModule(settings));
AuditTrail auditTrail = injector.getInstance(AuditTrail.class);
assertThat(auditTrail, is(AuditTrail.NOOP));
} }
public void testLogfile() throws Exception { public void testIndexAuditTrail() throws Exception {
Settings settings = Settings.builder() Settings settings = Settings.builder()
.put(AuditTrailModule.ENABLED_SETTING.getKey(), true) .put(AuditTrailModule.ENABLED_SETTING.getKey(), true)
.put("client.type", "node") .put(AuditTrailModule.OUTPUTS_SETTING.getKey(), "index").build();
.build(); AuditTrailModule module = new AuditTrailModule(settings);
ThreadPool pool = new TestThreadPool("testLogFile"); assertSetMultiBinding(module, AuditTrail.class, IndexAuditTrail.class);
try {
SettingsModule settingsModule = new SettingsModule(settings, AuditTrailModule.ENABLED_SETTING);
Injector injector = Guice.createInjector(
settingsModule,
new NetworkModule(new NetworkService(settings), settings, false, new NamedWriteableRegistry()) {
@Override
protected void configure() {
bind(Transport.class).to(LocalTransport.class).asEagerSingleton();
}
},
new AuditTrailModule(settings),
b -> {
b.bind(CircuitBreakerService.class).toInstance(Node.createCircuitBreakerService(settingsModule.getSettings(),
settingsModule.getClusterSettings()));
b.bind(ThreadPool.class).toInstance(pool);
}
);
AuditTrail auditTrail = injector.getInstance(AuditTrail.class);
assertThat(auditTrail, instanceOf(AuditTrailService.class));
AuditTrailService service = (AuditTrailService) auditTrail;
assertThat(service.auditTrails, notNullValue());
assertThat(service.auditTrails.length, is(1));
assertThat(service.auditTrails[0], instanceOf(LoggingAuditTrail.class));
} finally {
pool.shutdown();
} }
public void testIndexAndLoggingAuditTrail() throws Exception {
Settings settings = Settings.builder()
.put(AuditTrailModule.ENABLED_SETTING.getKey(), true)
.put(AuditTrailModule.OUTPUTS_SETTING.getKey(), "index,logfile").build();
AuditTrailModule module = new AuditTrailModule(settings);
assertSetMultiBinding(module, AuditTrail.class, IndexAuditTrail.class, LoggingAuditTrail.class);
} }
public void testUnknownOutput() throws Exception { public void testUnknownOutput() throws Exception {
Settings settings = Settings.builder() Settings settings = Settings.builder()
.put(AuditTrailModule.ENABLED_SETTING.getKey(), true) .put(AuditTrailModule.ENABLED_SETTING.getKey(), true)
.put(AuditTrailModule.OUTPUTS_SETTING.getKey() , "foo") .put(AuditTrailModule.OUTPUTS_SETTING.getKey(), "foo").build();
.put("client.type", "node") AuditTrailModule module = new AuditTrailModule(settings);
.build(); assertBindingFailure(module, "unknown audit trail output [foo]");
SettingsModule settingsModule = new SettingsModule(settings, AuditTrailModule.ENABLED_SETTING, AuditTrailModule.OUTPUTS_SETTING);
try {
Guice.createInjector(settingsModule, new AuditTrailModule(settings));
fail("Expect initialization to fail when an unknown audit trail output is configured");
} catch (Exception e) {
// expected
}
} }
} }

View File

@ -13,6 +13,7 @@ import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient; import org.elasticsearch.client.FilterClient;
import org.elasticsearch.client.transport.TransportClient; import org.elasticsearch.client.transport.TransportClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.util.Providers; import org.elasticsearch.common.inject.util.Providers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -49,7 +50,7 @@ public class IndexAuditTrailMutedTests extends ESTestCase {
private InternalClient client; private InternalClient client;
private TransportClient transportClient; private TransportClient transportClient;
private ThreadPool threadPool; private ThreadPool threadPool;
private Transport transport; private ClusterService clusterService;
private IndexAuditTrail auditTrail; private IndexAuditTrail auditTrail;
private AtomicBoolean messageEnqueued; private AtomicBoolean messageEnqueued;
@ -57,9 +58,10 @@ public class IndexAuditTrailMutedTests extends ESTestCase {
@Before @Before
public void setup() { public void setup() {
transport = mock(Transport.class); DiscoveryNode localNode = mock(DiscoveryNode.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(new TransportAddress[] { LocalTransportAddress.buildUnique() }, when(localNode.getHostAddress()).thenReturn(LocalTransportAddress.buildUnique().toString());
LocalTransportAddress.buildUnique())); clusterService = mock(ClusterService.class);
when(clusterService.localNode()).thenReturn(localNode);
threadPool = new TestThreadPool("index audit trail tests"); threadPool = new TestThreadPool("index audit trail tests");
transportClient = TransportClient.builder().settings(Settings.builder().put("transport.type", "local")).build(); transportClient = TransportClient.builder().settings(Settings.builder().put("transport.type", "local")).build();
@ -257,7 +259,7 @@ public class IndexAuditTrailMutedTests extends ESTestCase {
IndexAuditTrail createAuditTrail(String[] excludes) { IndexAuditTrail createAuditTrail(String[] excludes) {
Settings settings = IndexAuditTrailTests.levelSettings(null, excludes); Settings settings = IndexAuditTrailTests.levelSettings(null, excludes);
auditTrail = new IndexAuditTrail(settings, transport, Providers.of(client), threadPool, mock(ClusterService.class)) { auditTrail = new IndexAuditTrail(settings, Providers.of(client), threadPool, clusterService) {
@Override @Override
void putTemplate(Settings settings) { void putTemplate(Settings settings) {
// make this a no-op so we don't have to stub out unnecessary client activities // make this a no-op so we don't have to stub out unnecessary client activities

View File

@ -15,6 +15,7 @@ import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.client.Requests; import org.elasticsearch.client.Requests;
import org.elasticsearch.cluster.health.ClusterHealthStatus; import org.elasticsearch.cluster.health.ClusterHealthStatus;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Priority; import org.elasticsearch.common.Priority;
import org.elasticsearch.common.inject.util.Providers; import org.elasticsearch.common.inject.util.Providers;
@ -268,13 +269,14 @@ public class IndexAuditTrailTests extends SecurityIntegTestCase {
Settings settings = builder.put(settings(rollover, includes, excludes)).build(); Settings settings = builder.put(settings(rollover, includes, excludes)).build();
logger.info("--> settings: [{}]", settings.getAsMap().toString()); logger.info("--> settings: [{}]", settings.getAsMap().toString());
Transport transport = mock(Transport.class); DiscoveryNode localNode = mock(DiscoveryNode.class);
BoundTransportAddress boundTransportAddress = new BoundTransportAddress(new TransportAddress[]{ remoteHostAddress()}, when(localNode.getHostAddress()).thenReturn(remoteHostAddress().getAddress());
remoteHostAddress()); when(localNode.getHostName()).thenReturn(remoteHostAddress().getHost());
when(transport.boundAddress()).thenReturn(boundTransportAddress); ClusterService clusterService = mock(ClusterService.class);
when(clusterService.localNode()).thenReturn(localNode);
threadPool = new TestThreadPool("index audit trail tests"); threadPool = new TestThreadPool("index audit trail tests");
enqueuedMessage = new SetOnce<>(); enqueuedMessage = new SetOnce<>();
auditor = new IndexAuditTrail(settings, transport, Providers.of(internalClient()), threadPool, mock(ClusterService.class)) { auditor = new IndexAuditTrail(settings, Providers.of(internalClient()), threadPool, clusterService) {
@Override @Override
void enqueue(Message message, String type) { void enqueue(Message message, String type) {
enqueuedMessage.set(message); enqueuedMessage.set(message);

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.audit.index; package org.elasticsearch.xpack.security.audit.index;
import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsResponse; import org.elasticsearch.action.admin.indices.mapping.get.GetMappingsResponse;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.inject.util.Providers; import org.elasticsearch.common.inject.util.Providers;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
@ -48,11 +49,11 @@ public class IndexAuditTrailUpdateMappingTests extends SecurityIntegTestCase {
IndexNameResolver.Rollover rollover = randomFrom(HOURLY, DAILY, WEEKLY, MONTHLY); IndexNameResolver.Rollover rollover = randomFrom(HOURLY, DAILY, WEEKLY, MONTHLY);
Settings settings = Settings.builder().put("xpack.security.audit.index.rollover", rollover.name().toLowerCase(Locale.ENGLISH)) Settings settings = Settings.builder().put("xpack.security.audit.index.rollover", rollover.name().toLowerCase(Locale.ENGLISH))
.put("path.home", createTempDir()).build(); .put("path.home", createTempDir()).build();
Transport transport = mock(Transport.class); DiscoveryNode localNode = mock(DiscoveryNode.class);
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(new TransportAddress[] { LocalTransportAddress.buildUnique() }, when(localNode.getHostAddress()).thenReturn(LocalTransportAddress.buildUnique().toString());
LocalTransportAddress.buildUnique())); ClusterService clusterService = mock(ClusterService.class);
auditor = new IndexAuditTrail(settings, transport, Providers.of(internalClient()), threadPool, when(clusterService.localNode()).thenReturn(localNode);
mock(ClusterService.class)); auditor = new IndexAuditTrail(settings, Providers.of(internalClient()), threadPool, clusterService);
// before starting we add an event // before starting we add an event
auditor.authenticationFailed(new FakeRestRequest()); auditor.authenticationFailed(new FakeRestRequest());

View File

@ -7,6 +7,8 @@ package org.elasticsearch.xpack.security.audit.logfile;
import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.IndicesRequest;
import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.support.IndicesOptions;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.component.Lifecycle;
@ -41,9 +43,6 @@ import static org.hamcrest.Matchers.is;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
/**
*
*/
public class LoggingAuditTrailTests extends ESTestCase { public class LoggingAuditTrailTests extends ESTestCase {
private static enum RestContent { private static enum RestContent {
VALID() { VALID() {
@ -102,7 +101,8 @@ public class LoggingAuditTrailTests extends ESTestCase {
private String prefix; private String prefix;
private Settings settings; private Settings settings;
private Transport transport; private DiscoveryNode localNode;
private ClusterService clusterService;
private ThreadContext threadContext; private ThreadContext threadContext;
@Before @Before
@ -112,21 +112,20 @@ public class LoggingAuditTrailTests extends ESTestCase {
.put("xpack.security.audit.logfile.prefix.emit_node_host_name", randomBoolean()) .put("xpack.security.audit.logfile.prefix.emit_node_host_name", randomBoolean())
.put("xpack.security.audit.logfile.prefix.emit_node_name", randomBoolean()) .put("xpack.security.audit.logfile.prefix.emit_node_name", randomBoolean())
.build(); .build();
transport = mock(Transport.class); localNode = mock(DiscoveryNode.class);
when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(localNode.getHostAddress()).thenReturn(LocalTransportAddress.buildUnique().toString());
when(transport.boundAddress()).thenReturn(new BoundTransportAddress(new TransportAddress[] { LocalTransportAddress.buildUnique() }, clusterService = mock(ClusterService.class);
LocalTransportAddress.buildUnique())); when(clusterService.localNode()).thenReturn(localNode);
prefix = LoggingAuditTrail.resolvePrefix(settings, transport); prefix = LoggingAuditTrail.resolvePrefix(settings, localNode);
} }
public void testAnonymousAccessDeniedTransport() throws Exception { public void testAnonymousAccessDeniedTransport() throws Exception {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, clusterService.localNode(), threadContext);
auditTrail.anonymousAccessDenied("_action", message); auditTrail.anonymousAccessDenied("_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -164,8 +163,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.anonymousAccessDenied(request); auditTrail.anonymousAccessDenied(request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -188,10 +186,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext);; String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);;
auditTrail.authenticationFailed(new MockToken(), "_action", message); auditTrail.authenticationFailed(new MockToken(), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -222,10 +219,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext);; String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);;
auditTrail.authenticationFailed("_action", message); auditTrail.authenticationFailed("_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -261,8 +257,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.authenticationFailed(new MockToken(), request); auditTrail.authenticationFailed(new MockToken(), request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -289,8 +284,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.authenticationFailed(request); auditTrail.authenticationFailed(request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -311,10 +305,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext);; String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);;
auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message); auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -344,8 +337,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
when(request.uri()).thenReturn("_uri"); when(request.uri()).thenReturn("_uri");
String expectedMessage = prepareRestContent(request); String expectedMessage = prepareRestContent(request);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.authenticationFailed("_realm", new MockToken(), request); auditTrail.authenticationFailed("_realm", new MockToken(), request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -366,10 +358,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
boolean runAs = randomBoolean(); boolean runAs = randomBoolean();
User user; User user;
if (runAs) { if (runAs) {
@ -411,10 +402,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
auditTrail.accessGranted(SystemUser.INSTANCE, "internal:_action", message); auditTrail.accessGranted(SystemUser.INSTANCE, "internal:_action", message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -440,10 +430,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
boolean runAs = randomBoolean(); boolean runAs = randomBoolean();
User user; User user;
if (runAs) { if (runAs) {
@ -485,10 +474,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
boolean runAs = randomBoolean(); boolean runAs = randomBoolean();
User user; User user;
if (runAs) { if (runAs) {
@ -534,8 +522,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.tamperedRequest(request); auditTrail.tamperedRequest(request);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -557,10 +544,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.tamperedRequest(action, message); auditTrail.tamperedRequest(action, message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -599,10 +585,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext); TransportMessage message = randomBoolean() ? new MockMessage(threadContext) : new MockIndicesRequest(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
auditTrail.tamperedRequest(user, action, message); auditTrail.tamperedRequest(user, action, message);
switch (level) { switch (level) {
case ERROR: case ERROR:
@ -633,8 +618,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
InetAddress inetAddress = InetAddress.getLoopbackAddress(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
SecurityIpFilterRule rule = new SecurityIpFilterRule(false, "_all"); SecurityIpFilterRule rule = new SecurityIpFilterRule(false, "_all");
auditTrail.connectionDenied(inetAddress, "default", rule); auditTrail.connectionDenied(inetAddress, "default", rule);
@ -656,8 +640,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
InetAddress inetAddress = InetAddress.getLoopbackAddress(); InetAddress inetAddress = InetAddress.getLoopbackAddress();
SecurityIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; SecurityIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
auditTrail.connectionGranted(inetAddress, "default", rule); auditTrail.connectionGranted(inetAddress, "default", rule);
@ -680,10 +663,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = new MockMessage(threadContext); TransportMessage message = new MockMessage(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
User user = new User("_username", new String[]{"r1"}, new User("running as", new String[] {"r2"})); User user = new User("_username", new String[]{"r1"}, new User("running as", new String[] {"r2"}));
auditTrail.runAsGranted(user, "_action", message); auditTrail.runAsGranted(user, "_action", message);
switch (level) { switch (level) {
@ -707,10 +689,9 @@ public class LoggingAuditTrailTests extends ESTestCase {
for (Level level : Level.values()) { for (Level level : Level.values()) {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
CapturingLogger logger = new CapturingLogger(level); CapturingLogger logger = new CapturingLogger(level);
LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger, threadContext); LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, clusterService, logger, threadContext);
auditTrail.start();
TransportMessage message = new MockMessage(threadContext); TransportMessage message = new MockMessage(threadContext);
String origins = LoggingAuditTrail.originAttributes(message, transport, threadContext); String origins = LoggingAuditTrail.originAttributes(message, localNode, threadContext);
User user = new User("_username", new String[]{"r1"}, new User("running as", new String[] {"r2"})); User user = new User("_username", new String[]{"r1"}, new User("running as", new String[] {"r2"}));
auditTrail.runAsDenied(user, "_action", message); auditTrail.runAsDenied(user, "_action", message);
switch (level) { switch (level) {
@ -733,7 +714,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
public void testOriginAttributes() throws Exception { public void testOriginAttributes() throws Exception {
threadContext = new ThreadContext(Settings.EMPTY); threadContext = new ThreadContext(Settings.EMPTY);
MockMessage message = new MockMessage(threadContext); MockMessage message = new MockMessage(threadContext);
String text = LoggingAuditTrail.originAttributes(message, transport, threadContext);; String text = LoggingAuditTrail.originAttributes(message, localNode, threadContext);;
InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(threadContext); InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(threadContext);
if (restAddress != null) { if (restAddress != null) {
assertThat(text, equalTo("origin_type=[rest], origin_address=[" + assertThat(text, equalTo("origin_type=[rest], origin_address=[" +
@ -742,8 +723,7 @@ public class LoggingAuditTrailTests extends ESTestCase {
} }
TransportAddress address = message.remoteAddress(); TransportAddress address = message.remoteAddress();
if (address == null) { if (address == null) {
assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + assertThat(text, equalTo("origin_type=[local_node], origin_address=[" + localNode.getHostAddress() + "]"));
transport.boundAddress().publishAddress().getAddress() + "]"));
return; return;
} }