diff --git a/src/main/java/org/elasticsearch/shield/ShieldPlugin.java b/src/main/java/org/elasticsearch/shield/ShieldPlugin.java index f95c5735b6f..0afb3a97c9d 100644 --- a/src/main/java/org/elasticsearch/shield/ShieldPlugin.java +++ b/src/main/java/org/elasticsearch/shield/ShieldPlugin.java @@ -14,6 +14,7 @@ import org.elasticsearch.common.inject.Module; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.env.Environment; import org.elasticsearch.plugins.AbstractPlugin; +import org.elasticsearch.shield.audit.index.IndexAuditTrailBulkProcessor; import org.elasticsearch.shield.authc.Realms; import org.elasticsearch.shield.authc.support.SecuredString; import org.elasticsearch.shield.authc.support.UsernamePasswordToken; @@ -26,6 +27,8 @@ import java.nio.file.Path; import java.util.Collection; import java.util.Map; +import static org.elasticsearch.shield.audit.AuditTrailModule.indexAuditLoggingEnabled; + /** * */ @@ -64,9 +67,15 @@ public class ShieldPlugin extends AbstractPlugin { @Override public Collection> services() { - return enabled && !clientMode ? - ImmutableList.>of(LicenseService.class, InternalCryptoService.class, FileRolesStore.class, Realms.class, IPFilter.class) : - ImmutableList.>of(); + ImmutableList.Builder> builder = ImmutableList.builder(); + if (enabled && !clientMode) { + if (indexAuditLoggingEnabled(settings)) { + // index-based audit logging should be started before other services + builder.add(IndexAuditTrailBulkProcessor.class); + } + builder.add(LicenseService.class).add(InternalCryptoService.class).add(FileRolesStore.class).add(Realms.class).add(IPFilter.class); + } + return builder.build(); } @Override diff --git a/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java b/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java index 2ba0c96df0d..ca4906ea4ec 100644 --- a/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java +++ b/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java @@ -7,9 +7,15 @@ package org.elasticsearch.shield.audit; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.common.collect.Sets; +import org.elasticsearch.common.inject.Module; +import org.elasticsearch.common.inject.PreProcessModule; import org.elasticsearch.common.inject.multibindings.Multibinder; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.shield.audit.index.IndexAuditTrail; +import org.elasticsearch.shield.audit.index.IndexAuditTrailBulkProcessor; +import org.elasticsearch.shield.audit.index.IndexAuditUserHolder; import org.elasticsearch.shield.audit.logfile.LoggingAuditTrail; +import org.elasticsearch.shield.authz.AuthorizationModule; import org.elasticsearch.shield.support.AbstractShieldModule; import java.util.Set; @@ -17,10 +23,12 @@ import java.util.Set; /** * */ -public class AuditTrailModule extends AbstractShieldModule.Node { +public class AuditTrailModule extends AbstractShieldModule.Node implements PreProcessModule { private final boolean enabled; + private IndexAuditUserHolder indexAuditUser; + public AuditTrailModule(Settings settings) { super(settings); enabled = settings.getAsBoolean("shield.audit.enabled", false); @@ -45,10 +53,37 @@ public class AuditTrailModule extends AbstractShieldModule.Node { switch (output) { case LoggingAuditTrail.NAME: binder.addBinding().to(LoggingAuditTrail.class); + bind(LoggingAuditTrail.class).asEagerSingleton(); + break; + case IndexAuditTrail.NAME: + bind(IndexAuditUserHolder.class).toInstance(indexAuditUser); + binder.addBinding().to(IndexAuditTrail.class); + bind(IndexAuditTrail.class).asEagerSingleton(); + bind(IndexAuditTrailBulkProcessor.class).asEagerSingleton(); break; default: throw new ElasticsearchException("unknown audit trail output [" + output + "]"); } } } + + @Override + public void processModule(Module module) { + if (enabled && module instanceof AuthorizationModule) { + if (indexAuditLoggingEnabled(settings)) { + indexAuditUser = new IndexAuditUserHolder(IndexAuditTrailBulkProcessor.INDEX_NAME_PREFIX); + ((AuthorizationModule) module).registerReservedRole(indexAuditUser.role()); + } + } + } + + public static boolean indexAuditLoggingEnabled(Settings settings) { + String[] outputs = settings.getAsArray("shield.audit.outputs"); + for (String output : outputs) { + if (output.equals(IndexAuditTrail.NAME)) { + return true; + } + } + return false; + } } diff --git a/src/main/java/org/elasticsearch/shield/audit/AuditUtil.java b/src/main/java/org/elasticsearch/shield/audit/AuditUtil.java new file mode 100644 index 00000000000..6211c0b07fc --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/audit/AuditUtil.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit; + +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.transport.TransportMessage; + +import java.io.IOException; + +/** + * + */ +public class AuditUtil { + + public static String restRequestContent(RestRequest request) { + if (request.hasContent()) { + try { + return XContentHelper.convertToJson(request.content(), false, false); + } catch (IOException ioe) { + return "Invalid Format: " + request.content().toUtf8(); + } + } + return ""; + } + + public static String indices(TransportMessage message) { + if (message instanceof IndicesRequest) { + return Strings.arrayToCommaDelimitedString(((IndicesRequest) message).indices()); + } + return null; + } +} diff --git a/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java new file mode 100644 index 00000000000..d1e2184d43e --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java @@ -0,0 +1,412 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.logging.ESLogger; +import org.elasticsearch.common.logging.ESLoggerFactory; +import org.elasticsearch.common.network.NetworkUtils; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.InetSocketTransportAddress; +import org.elasticsearch.common.transport.TransportAddress; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentBuilderString; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.shield.User; +import org.elasticsearch.shield.audit.AuditTrail; +import org.elasticsearch.shield.authc.AuthenticationToken; +import org.elasticsearch.shield.authz.Privilege; +import org.elasticsearch.shield.rest.RemoteHostHeader; +import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule; +import org.elasticsearch.transport.TransportMessage; +import org.elasticsearch.transport.TransportRequest; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.EnumSet; + +import static org.elasticsearch.shield.audit.AuditUtil.restRequestContent; +import static org.elasticsearch.shield.audit.AuditUtil.indices; + +/** + * Audit trail implementation that writes events into an index. + */ +public class IndexAuditTrail implements AuditTrail { + + private static final ESLogger logger = ESLoggerFactory.getLogger(IndexAuditTrail.class.getName()); + + public static final String NAME = "index"; + + private final String nodeName; + private final String nodeHostName; + private final String nodeHostAddress; + private final IndexAuditUserHolder auditUser; + private final IndexAuditTrailBulkProcessor processor; + + @Override + public String name() { + return NAME; + } + + private enum Level { + ANONYMOUS_ACCESS_DENIED, + AUTHENTICATION_FAILED, + ACCESS_GRANTED, + ACCESS_DENIED, + TAMPERED_REQUEST, + CONNECTION_GRANTED, + CONNECTION_DENIED, + SYSTEM_ACCESS_GRANTED; + } + + private final EnumSet enabled = EnumSet.allOf(Level.class); + + @Inject + public IndexAuditTrail(Settings settings, IndexAuditUserHolder indexingAuditUser, + IndexAuditTrailBulkProcessor processor) { + + this.auditUser = indexingAuditUser; + this.processor = processor; + this.nodeName = settings.get("name"); + + String hostname = "n/a"; + String hostaddr = "n/a"; + + try { + hostname = InetAddress.getLocalHost().getHostName(); + hostaddr = InetAddress.getLocalHost().getHostAddress(); + } catch (UnknownHostException e) { + logger.warn("unable to resolve local host name", e); + } + this.nodeHostName = hostname; + this.nodeHostAddress = hostaddr; + + if (!settings.getAsBoolean("shield.audit.index.events.system.access_granted", false)) { + enabled.remove(Level.SYSTEM_ACCESS_GRANTED); + } + if (!settings.getAsBoolean("shield.audit.index.events.anonymous_access_denied", true)) { + enabled.remove(Level.ANONYMOUS_ACCESS_DENIED); + } + if (!settings.getAsBoolean("shield.audit.index.events.authentication_failed", true)) { + enabled.remove(Level.AUTHENTICATION_FAILED); + } + if (!settings.getAsBoolean("shield.audit.index.events.access_granted", true)) { + enabled.remove(Level.ACCESS_GRANTED); + } + if (!settings.getAsBoolean("shield.audit.index.events.access_denied", true)) { + enabled.remove(Level.ACCESS_DENIED); + } + if (!settings.getAsBoolean("shield.audit.index.events.tampered_request", true)) { + enabled.remove(Level.TAMPERED_REQUEST); + } + if (!settings.getAsBoolean("shield.audit.index.events.connection_granted", true)) { + enabled.remove(Level.CONNECTION_GRANTED); + } + if (!settings.getAsBoolean("shield.audit.index.events.connection_denied", true)) { + enabled.remove(Level.CONNECTION_DENIED); + } + } + + @Override + public void anonymousAccessDenied(String action, TransportMessage message) { + if (enabled.contains(Level.ANONYMOUS_ACCESS_DENIED)) { + try { + processor.submit(message("anonymous_access_denied", action, null, null, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [anonymous_access_denied]", e); + } + } + } + + @Override + public void anonymousAccessDenied(RestRequest request) { + if (enabled.contains(Level.ANONYMOUS_ACCESS_DENIED)) { + try { + processor.submit(message("anonymous_access_denied", null, null, null, null, request)); + } catch (Exception e) { + logger.warn("failed to index audit event: [anonymous_access_denied]", e); + } + } + } + + @Override + public void authenticationFailed(AuthenticationToken token, String action, TransportMessage message) { + if (enabled.contains(Level.AUTHENTICATION_FAILED)) { + if (!principalIsAuditor(token.principal())) { + try { + processor.submit(message("authentication_failed", action, token.principal(), null, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [authentication_failed]", e); + } + } + } + } + + @Override + public void authenticationFailed(AuthenticationToken token, RestRequest request) { + if (enabled.contains(Level.AUTHENTICATION_FAILED)) { + if (!principalIsAuditor(token.principal())) { + try { + processor.submit(message("authentication_failed", null, token.principal(), null, null, request)); + } catch (Exception e) { + logger.warn("failed to index audit event: [authentication_failed]", e); + } + } + } + } + + @Override + public void authenticationFailed(String realm, AuthenticationToken token, String action, TransportMessage message) { + if (enabled.contains(Level.AUTHENTICATION_FAILED)) { + if (!principalIsAuditor(token.principal())) { + try { + processor.submit(message("authentication_failed", action, token.principal(), realm, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [authentication_failed]", e); + } + } + } + } + + @Override + public void authenticationFailed(String realm, AuthenticationToken token, RestRequest request) { + if (enabled.contains(Level.AUTHENTICATION_FAILED)) { + if (!principalIsAuditor(token.principal())) { + try { + processor.submit(message("authentication_failed", null, token.principal(), realm, null, request)); + } catch (Exception e) { + logger.warn("failed to index audit event: [authentication_failed]", e); + } + } + } + } + + @Override + public void accessGranted(User user, String action, TransportMessage message) { + if (enabled.contains(Level.ACCESS_GRANTED)) { + if (!principalIsAuditor(user.principal())) { + // special treatment for internal system actions - only log if explicitly told to + if (Privilege.SYSTEM.internalActionPredicate().apply(action)) { + if (enabled.contains(Level.SYSTEM_ACCESS_GRANTED)) { + try { + processor.submit(message("access_granted", action, user.principal(), null, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [access_granted]", e); + } + } + } else { + try { + processor.submit(message("access_granted", action, user.principal(), null, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [access_granted]", e); + } + } + } + } + } + + @Override + public void accessDenied(User user, String action, TransportMessage message) { + if (enabled.contains(Level.ACCESS_DENIED)) { + if (!principalIsAuditor(user.principal())) { + try { + processor.submit(message("access_denied", action, user.principal(), null, indices(message), message)); + } catch (Exception e) { + logger.warn("failed to index audit event: [access_denied]", e); + } + } + } + } + + @Override + public void tamperedRequest(User user, String action, TransportRequest request) { + if (enabled.contains(Level.TAMPERED_REQUEST)) { + if (!principalIsAuditor(user.principal())) { + try { + processor.submit(message("tampered_request", action, user.principal(), null, indices(request), request)); + } catch (Exception e) { + logger.warn("failed to index audit event: [tampered_request]", e); + } + } + } + } + + @Override + public void connectionGranted(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) { + if (enabled.contains(Level.CONNECTION_GRANTED)) { + try { + processor.submit(message("ip_filter", "connection_granted", inetAddress, profile, rule)); + } catch (Exception e) { + logger.warn("failed to index audit event: [connection_granted]", e); + } + } + } + + @Override + public void connectionDenied(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) { + if (enabled.contains(Level.CONNECTION_DENIED)) { + try { + processor.submit(message("ip_filter", "connection_denied", inetAddress, profile, rule)); + } catch (Exception e) { + logger.warn("failed to index audit event: [connection_denied]", e); + } + } + } + + private boolean principalIsAuditor(String principal) { + return (principal.equals(auditUser.user().principal())); + } + + private Message message(String type, @Nullable String action, @Nullable String principal, + @Nullable String realm, @Nullable String indices, TransportMessage message) throws Exception { + + Message msg = new Message().start(); + common("transport", type, msg.builder); + originAttributes(message, msg.builder); + + if (action != null) { + msg.builder.field(Field.ACTION, action); + } + if (principal != null) { + msg.builder.field(Field.PRINCIPAL, principal); + } + if (realm != null) { + msg.builder.field(Field.REALM, realm); + } + if (indices != null) { + msg.builder.field(Field.INDICES, indices); + } + if (logger.isDebugEnabled()) { + msg.builder.field(Field.REQUEST, message.getClass().getSimpleName()); + } + + return msg.end(); + } + + private Message message(String type, @Nullable String action, @Nullable String principal, + @Nullable String realm, @Nullable String indices, RestRequest request) throws Exception { + + Message msg = new Message().start(); + common("rest", type, msg.builder); + + if (action != null) { + msg.builder.field(Field.ACTION, action); + } + if (principal != null) { + msg.builder.field(Field.PRINCIPAL, principal); + } + if (realm != null) { + msg.builder.field(Field.REALM, realm); + } + if (indices != null) { + msg.builder.field(Field.INDICES, indices); + } + if (logger.isDebugEnabled()) { + msg.builder.field(Field.REQUEST_BODY, restRequestContent(request)); + } + + msg.builder.field(Field.ORIGIN_ADDRESS, request.getRemoteAddress()); + msg.builder.field(Field.URI, request.uri()); + + return msg.end(); + } + + private Message message(String layer, String type, InetAddress originAddress, String profile, + ShieldIpFilterRule rule) throws IOException { + + Message msg = new Message().start(); + common(layer, type, msg.builder); + + msg.builder.field(Field.ORIGIN_ADDRESS, originAddress.getHostAddress()); + msg.builder.field(Field.TRANSPORT_PROFILE, profile); + msg.builder.field(Field.RULE, rule); + + return msg.end(); + } + + private XContentBuilder common(String layer, String type, XContentBuilder builder) throws IOException { + builder.field(Field.NODE_NAME, nodeName); + builder.field(Field.NODE_HOST_NAME, nodeHostName); + builder.field(Field.NODE_HOST_ADDRESS, nodeHostAddress); + builder.field(Field.LAYER, layer); + builder.field(Field.TYPE, type); + return builder; + } + + private static XContentBuilder originAttributes(TransportMessage message, XContentBuilder builder) throws IOException { + + // first checking if the message originated in a rest call + InetSocketAddress restAddress = RemoteHostHeader.restRemoteAddress(message); + if (restAddress != null) { + builder.field(Field.ORIGIN_TYPE, "rest"); + builder.field(Field.ORIGIN_ADDRESS, restAddress); + return builder; + } + + // we'll see if was originated in a remote node + TransportAddress address = message.remoteAddress(); + if (address != null) { + builder.field(Field.ORIGIN_TYPE, "transport"); + if (address instanceof InetSocketTransportAddress) { + builder.field(Field.ORIGIN_ADDRESS, ((InetSocketTransportAddress) address).address()); + } else { + builder.field(Field.ORIGIN_ADDRESS, address); + } + return builder; + } + + // the call was originated locally on this node + builder.field(Field.ORIGIN_TYPE, "local_node"); + builder.field(Field.ORIGIN_ADDRESS, NetworkUtils.getLocalHostAddress("_local")); + return builder; + } + + static class Message { + + final long timestamp; + final XContentBuilder builder; + + Message() throws IOException { + this.timestamp = System.currentTimeMillis(); + this.builder = XContentFactory.jsonBuilder(); + } + + Message start() throws IOException { + builder.startObject(); + builder.field(Field.TIMESTAMP, timestamp); + return this; + } + + Message end() throws IOException { + builder.endObject(); + return this; + } + } + + interface Field { + XContentBuilderString TIMESTAMP = new XContentBuilderString("timestamp"); + XContentBuilderString NODE_NAME = new XContentBuilderString("node_name"); + XContentBuilderString NODE_HOST_NAME = new XContentBuilderString("node_host_name"); + XContentBuilderString NODE_HOST_ADDRESS = new XContentBuilderString("node_host_address"); + XContentBuilderString LAYER = new XContentBuilderString("layer"); + XContentBuilderString TYPE = new XContentBuilderString("type"); + XContentBuilderString ORIGIN_ADDRESS = new XContentBuilderString("origin_address"); + XContentBuilderString ORIGIN_TYPE = new XContentBuilderString("origin_type"); + XContentBuilderString PRINCIPAL = new XContentBuilderString("principal"); + XContentBuilderString ACTION = new XContentBuilderString("action"); + XContentBuilderString INDICES = new XContentBuilderString("indices"); + XContentBuilderString REQUEST = new XContentBuilderString("request"); + XContentBuilderString REQUEST_BODY = new XContentBuilderString("request_body"); + XContentBuilderString URI = new XContentBuilderString("uri"); + XContentBuilderString REALM = new XContentBuilderString("realm"); + XContentBuilderString TRANSPORT_PROFILE = new XContentBuilderString("transport_profile"); + XContentBuilderString RULE = new XContentBuilderString("rule"); + } +} diff --git a/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrailBulkProcessor.java b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrailBulkProcessor.java new file mode 100644 index 00000000000..2461e70717c --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrailBulkProcessor.java @@ -0,0 +1,189 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.bulk.BulkProcessor; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.client.Client; +import org.elasticsearch.client.transport.TransportClient; +import org.elasticsearch.common.base.Splitter; +import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.component.AbstractLifecycleComponent; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.inject.Provider; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.InetSocketTransportAddress; +import org.elasticsearch.common.unit.TimeValue; +import org.elasticsearch.env.Environment; +import org.elasticsearch.shield.authc.AuthenticationService; + +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +/** + * + */ +public class IndexAuditTrailBulkProcessor extends AbstractLifecycleComponent { + + public static final int DEFAULT_BULK_SIZE = 1000; + public static final int MAX_BULK_SIZE = 10000; + public static final String INDEX_NAME_PREFIX = ".shield-audit-log"; + public static final String DOC_TYPE = "event"; + + public static final TimeValue DEFAULT_FLUSH_INTERVAL = TimeValue.timeValueSeconds(1); + public static final IndexNameResolver.Rollover DEFAULT_ROLLOVER = IndexNameResolver.Rollover.DAILY; + + private static BulkProcessor bulkProcessor; + + private final Provider clientProvider; + private final IndexAuditUserHolder auditUser; + private final AuthenticationService authenticationService; + private final IndexNameResolver resolver; + private final IndexNameResolver.Rollover rollover; + private final Environment environment; + + private Client client; + private boolean indexToRemoteCluster; + + @Inject + public IndexAuditTrailBulkProcessor(Settings settings, Environment environment, AuthenticationService authenticationService, + IndexAuditUserHolder auditUser, Provider clientProvider) { + super(settings); + this.authenticationService = authenticationService; + this.auditUser = auditUser; + this.clientProvider = clientProvider; + this.environment = environment; + + IndexNameResolver.Rollover rollover; + try { + rollover = IndexNameResolver.Rollover.valueOf( + settings.get("shield.audit.index.rollover", DEFAULT_ROLLOVER.name()).toUpperCase(Locale.ENGLISH)); + } catch (IllegalArgumentException e) { + logger.warn("invalid value for setting [shield.audit.index.rollover]; falling back to default [{}]", + DEFAULT_ROLLOVER.name()); + rollover = DEFAULT_ROLLOVER; + } + this.rollover = rollover; + this.resolver = new IndexNameResolver(); + } + + @Override + protected void doStart() throws ElasticsearchException { + initializeClient(); + initializeBulkProcessor(); + } + + @Override + protected void doStop() throws ElasticsearchException { + } + + @Override + protected void doClose() throws ElasticsearchException { + try { + if (bulkProcessor != null) { + bulkProcessor.close(); + } + } finally { + if (indexToRemoteCluster) { + if (client != null) { + client.close(); + } + } + } + } + + public void submit(IndexAuditTrail.Message message) { + assert lifecycle.started(); + IndexRequest indexRequest = client.prepareIndex() + .setIndex(resolver.resolve(INDEX_NAME_PREFIX, message.timestamp, rollover)) + .setType(DOC_TYPE).setSource(message.builder).request(); + authenticationService.attachUserHeaderIfMissing(indexRequest, auditUser.user()); + bulkProcessor.add(indexRequest); + } + + private void initializeClient() { + + Settings clientSettings = settings.getByPrefix("shield.audit.index.client."); + + if (clientSettings.names().size() == 0) { + // in the absence of client settings for remote indexing, fall back to the client that was passed in. + this.client = clientProvider.get(); + indexToRemoteCluster = false; + } else { + String[] hosts = clientSettings.getAsArray("hosts"); + if (hosts.length == 0) { + throw new ElasticsearchException("missing required setting " + + "[shield.audit.index.client.hosts] for remote audit log indexing"); + } + + if (clientSettings.get("cluster.name", "").isEmpty()) { + throw new ElasticsearchException("missing required setting " + + "[shield.audit.index.client.cluster.name] for remote audit log indexing"); + } + + List> hostPortPairs = new ArrayList<>(); + + for (String host : hosts) { + List hostPort = Splitter.on(":").splitToList(host.trim()); + if (hostPort.size() != 1 && hostPort.size() != 2) { + logger.warn("invalid host:port specified: [{}] for setting [shield.audit.index.client.hosts]", host); + } + hostPortPairs.add(new Tuple<>(hostPort.get(0), hostPort.size() == 2 ? Integer.valueOf(hostPort.get(1)) : 9300)); + } + + if (hostPortPairs.size() == 0) { + throw new ElasticsearchException("no valid host:port pairs specified for setting [shield.audit.index.client.hosts]"); + } + + final TransportClient transportClient = TransportClient.builder() + .settings(Settings.builder().put(clientSettings).put("path.home", environment.homeFile()).build()).build(); + for (Tuple pair : hostPortPairs) { + transportClient.addTransportAddress(new InetSocketTransportAddress(pair.v1(), pair.v2())); + } + + this.client = transportClient; + indexToRemoteCluster = true; + + logger.info("forwarding audit events to remote cluster [{}] using hosts [{}]", + clientSettings.get("cluster.name", ""), hostPortPairs.toString()); + } + } + + private void initializeBulkProcessor() { + + int bulkSize = Math.min(settings.getAsInt("shield.audit.index.bulk_size", DEFAULT_BULK_SIZE), MAX_BULK_SIZE); + bulkSize = (bulkSize < 1) ? DEFAULT_BULK_SIZE : bulkSize; + + TimeValue interval = settings.getAsTime("shield.audit.index.flush_interval", DEFAULT_FLUSH_INTERVAL); + interval = (interval.millis() < 1) ? DEFAULT_FLUSH_INTERVAL : interval; + + bulkProcessor = BulkProcessor.builder(client, new BulkProcessor.Listener() { + @Override + public void beforeBulk(long executionId, BulkRequest request) { + authenticationService.attachUserHeaderIfMissing(request, auditUser.user()); + } + + @Override + public void afterBulk(long executionId, BulkRequest request, BulkResponse response) { + if (response.hasFailures()) { + logger.info("failed to bulk index audit events: [{}]", response.buildFailureMessage()); + } + } + + @Override + public void afterBulk(long executionId, BulkRequest request, Throwable failure) { + logger.error("failed to bulk index audit events: [{}]", failure, failure.getMessage()); + } + }).setBulkActions(bulkSize) + .setFlushInterval(interval) + .setConcurrentRequests(1) + .build(); + } +} diff --git a/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditUserHolder.java b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditUserHolder.java new file mode 100644 index 00000000000..1c24f5b63bf --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditUserHolder.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.shield.User; +import org.elasticsearch.shield.authz.Permission; +import org.elasticsearch.shield.authz.Privilege; + +/** + * + */ +public class IndexAuditUserHolder { + + private static final String NAME = "__indexing_audit_user"; + private static final String[] ROLE_NAMES = new String[] { "__indexing_audit_role" }; + + private final User user; + private final Permission.Global.Role role; + + public IndexAuditUserHolder(String indexName) { + + // append the index name with the '*' wildcard so that the principal can write to + // any index that starts with the given name. this allows us to rollover over + // audit indices hourly, daily, weekly, etc. + String indexPattern = indexName + "*"; + + this.role = Permission.Global.Role.builder(ROLE_NAMES[0]) + .add(Privilege.Index.CREATE_INDEX, indexPattern) + .add(Privilege.Index.INDEX, indexPattern) + .add(Privilege.Index.action(BulkAction.NAME), indexPattern) + .build(); + + this.user = new User.Simple(NAME, ROLE_NAMES); + } + + public User user() { + return user; + } + + public Permission.Global.Role role() { + return role; + } +} diff --git a/src/main/java/org/elasticsearch/shield/audit/index/IndexNameResolver.java b/src/main/java/org/elasticsearch/shield/audit/index/IndexNameResolver.java new file mode 100644 index 00000000000..b5542694d25 --- /dev/null +++ b/src/main/java/org/elasticsearch/shield/audit/index/IndexNameResolver.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; + +/** + * + */ +public class IndexNameResolver { + + private final DateFormat formatter = DateFormat.getDateTimeInstance(DateFormat.LONG, DateFormat.LONG, Locale.ROOT); + + public enum Rollover { + HOURLY ("-yyyy-MM-dd-HH"), + DAILY ("-yyyy-MM-dd"), + WEEKLY ("-yyyy-w"), + MONTHLY ("-yyyy-MM"); + + private final String format; + + Rollover(String format) { + this.format = format; + } + } + + public String resolve(long timestamp, Rollover rollover) { + Date date = new Date(timestamp); + ((SimpleDateFormat) formatter).applyPattern(rollover.format); + return formatter.format(date); + } + + public String resolve(String indexNamePrefix, long timestamp, Rollover rollover) { + return indexNamePrefix + resolve(timestamp, rollover); + } +} diff --git a/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java b/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java index 10a32eb0d3e..93885cc7e95 100644 --- a/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java +++ b/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java @@ -30,6 +30,9 @@ import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.UnknownHostException; +import static org.elasticsearch.shield.audit.AuditUtil.restRequestContent; +import static org.elasticsearch.shield.audit.AuditUtil.indices; + /** * */ @@ -215,24 +218,6 @@ public class LoggingAuditTrail implements AuditTrail { logger.error("{}[ip_filter] [connection_denied]\torigin_address=[{}], transport_profile=[{}], rule=[{}]", prefix, inetAddress.getHostAddress(), profile, rule); } - private static String indices(TransportMessage message) { - if (message instanceof IndicesRequest) { - return Strings.arrayToCommaDelimitedString(((IndicesRequest) message).indices()); - } - return null; - } - - private static String restRequestContent(RestRequest request) { - if (request.hasContent()) { - try { - return XContentHelper.convertToJson(request.content(), false, false); - } catch (IOException ioe) { - return "Invalid Format: " + request.content().toUtf8(); - } - } - return ""; - } - private static String hostAttributes(RestRequest request) { return "origin_address=[" + request.getRemoteAddress() + "]"; } diff --git a/src/main/java/org/elasticsearch/shield/authz/Privilege.java b/src/main/java/org/elasticsearch/shield/authz/Privilege.java index 1b4fa457e5b..9b36742b4c1 100644 --- a/src/main/java/org/elasticsearch/shield/authz/Privilege.java +++ b/src/main/java/org/elasticsearch/shield/authz/Privilege.java @@ -213,7 +213,7 @@ public abstract class Privilege

> { public static void addCustom(String name, String... actionPatterns) { for (String pattern : actionPatterns) { if (!Index.ACTION_MATCHER.apply(pattern)) { - throw new ShieldException("cannot register custom index privilege [" + name + "]. index aciton must follow the 'indices:*' format"); + throw new ShieldException("cannot register custom index privilege [" + name + "]. index action must follow the 'indices:*' format"); } } Index custom = new Index(name, actionPatterns); diff --git a/src/main/java/org/elasticsearch/shield/client/ShieldClient.java b/src/main/java/org/elasticsearch/shield/client/ShieldClient.java index cb81e2c5bc8..65e790af811 100644 --- a/src/main/java/org/elasticsearch/shield/client/ShieldClient.java +++ b/src/main/java/org/elasticsearch/shield/client/ShieldClient.java @@ -21,7 +21,7 @@ public class ShieldClient { } /** - * @return The Shield authenticatin client. + * @return The Shield authentication client. */ public ShieldAuthcClient authc() { return authcClient; diff --git a/src/test/java/org/elasticsearch/shield/audit/index/IndexAuditTrailTests.java b/src/test/java/org/elasticsearch/shield/audit/index/IndexAuditTrailTests.java new file mode 100644 index 00000000000..ae18d0b8dc6 --- /dev/null +++ b/src/test/java/org/elasticsearch/shield/audit/index/IndexAuditTrailTests.java @@ -0,0 +1,530 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.elasticsearch.action.admin.indices.delete.DeleteIndexResponse; +import org.elasticsearch.action.exists.ExistsResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.client.Client; +import org.elasticsearch.common.base.Predicate; +import org.elasticsearch.common.inject.util.Providers; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.InetSocketTransportAddress; +import org.elasticsearch.common.transport.LocalTransportAddress; +import org.elasticsearch.env.Environment; +import org.elasticsearch.indices.IndexMissingException; +import org.elasticsearch.node.Node; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.shield.User; +import org.elasticsearch.shield.authc.AuthenticationService; +import org.elasticsearch.shield.authc.AuthenticationToken; +import org.elasticsearch.shield.transport.filter.IPFilter; +import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule; +import org.elasticsearch.test.ElasticsearchIntegrationTest; +import org.elasticsearch.test.ShieldIntegrationTest; +import org.elasticsearch.transport.TransportInfo; +import org.elasticsearch.transport.TransportMessage; +import org.elasticsearch.transport.TransportRequest; + +import org.junit.After; +import org.junit.Test; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Locale; + +import static org.elasticsearch.node.NodeBuilder.*; +import static org.elasticsearch.test.ElasticsearchIntegrationTest.Scope.*; +import static org.elasticsearch.shield.audit.index.IndexNameResolver.Rollover.*; +import static org.hamcrest.Matchers.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * + */ +@ElasticsearchIntegrationTest.ClusterScope(scope = SUITE, numDataNodes = 1) +public class IndexAuditTrailTests extends ShieldIntegrationTest { + + private IndexNameResolver resolver = new IndexNameResolver(); + private IndexNameResolver.Rollover rollover; + private IndexAuditTrailBulkProcessor bulkProcessor; + private IndexAuditTrail auditor; + + private boolean remoteIndexing = false; + private Node remoteNode; + private Client remoteClient; + + public static final String REMOTE_TEST_CLUSTER = "single-node-remote-test-cluster"; + + private static final IndexAuditUserHolder user = new IndexAuditUserHolder(IndexAuditTrailBulkProcessor.INDEX_NAME_PREFIX); + + private Settings commonSettings(IndexNameResolver.Rollover rollover) { + return Settings.builder() + .put("shield.audit.enabled", true) + .put("shield.audit.outputs", "index, logfile") + .put("shield.audit.index.bulk_size", 1) + .put("shield.audit.index.flush_interval", "1ms") + .put("shield.audit.index.rollover", rollover.name().toLowerCase(Locale.ENGLISH)) + .build(); + } + + private Settings remoteSettings(String address, int port, String clusterName) { + return Settings.builder() + .put("shield.audit.index.client.hosts", address + ":" + port) + .put("shield.audit.index.client.cluster.name", clusterName) + .build(); + } + + private Settings mutedSettings(String... muted) { + Settings.Builder builder = Settings.builder(); + for (String mute : muted) { + builder.put("shield.audit.index.events." + mute, false); + } + return builder.build(); + } + + private Settings settings(IndexNameResolver.Rollover rollover, String... muted) { + Settings.Builder builder = Settings.builder(); + builder.put(mutedSettings(muted)); + builder.put(commonSettings(rollover)); + return builder.build(); + } + + private IndexAuditTrailBulkProcessor buildIndexAuditTrailService(Settings settings) { + + AuthenticationService authService = mock(AuthenticationService.class); + when(authService.authenticate(mock(RestRequest.class))).thenThrow(UnsupportedOperationException.class); + when(authService.authenticate("_action", new LocalHostMockMessage(), user.user())).thenThrow(UnsupportedOperationException.class); + + Environment env = new Environment(settings); + return new IndexAuditTrailBulkProcessor(settings, env, authService, user, Providers.of(client())); + } + + private Client getClient() { + return remoteIndexing ? remoteClient : client(); + } + + private void initialize(String... muted) { + + rollover = randomFrom(HOURLY, DAILY, WEEKLY, MONTHLY); + Settings settings = settings(rollover, muted); + remoteIndexing = randomBoolean(); + + if (remoteIndexing) { + // start a small single-node cluster to test remote indexing against + logger.info("--> remote indexing enabled"); + Settings s = Settings.builder().put("shield.enabled", "false").put("path.home", createTempDir()).build(); + remoteNode = nodeBuilder().clusterName(REMOTE_TEST_CLUSTER).data(true).settings(s).node(); + remoteClient = remoteNode.client(); + + NodesInfoResponse response = remoteClient.admin().cluster().prepareNodesInfo().execute().actionGet(); + TransportInfo info = response.getNodes()[0].getTransport(); + InetSocketTransportAddress inet = (InetSocketTransportAddress) info.address().publishAddress(); + + settings = Settings.builder() + .put(settings) + .put(remoteSettings(inet.address().getAddress().getHostAddress(), inet.address().getPort(), REMOTE_TEST_CLUSTER)) + .build(); + } + + Settings settings1 = Settings.builder().put(settings).put("path.home", createTempDir()).build(); + logger.info("--> settings: [{}]", settings.getAsMap().toString()); + bulkProcessor = buildIndexAuditTrailService(settings1); + bulkProcessor.start(); + auditor = new IndexAuditTrail(settings, user, bulkProcessor); + } + + @After + public void afterTest() { + bulkProcessor.close(); + cluster().wipe(); + if (remoteIndexing && remoteNode != null) { + DeleteIndexResponse response = remoteClient.admin().indices().prepareDelete("*").execute().actionGet(); + assertTrue(response.isAcknowledged()); + remoteClient.close(); + remoteNode.close(); + } + } + + @Test + public void testAnonymousAccessDenied_Transport() throws Exception { + + initialize(); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.anonymousAccessDenied("_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + assertAuditMessage(hit, "transport", "anonymous_access_denied"); + + if (message instanceof RemoteHostMockMessage) { + assertEquals("remote_host:1234", hit.field("origin_address").getValue()); + } else { + assertEquals("local[local_host]", hit.field("origin_address").getValue()); + } + + assertEquals("_action", hit.field("action").getValue()); + assertEquals("transport", hit.field("origin_type").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAnonymousAccessDenied_Transport_Muted() throws Exception { + initialize("anonymous_access_denied"); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.anonymousAccessDenied("_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAnonymousAccessDenied_Rest() throws Exception { + + initialize(); + RestRequest request = mockRestRequest(); + auditor.anonymousAccessDenied(request); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "rest", "anonymous_access_denied"); + assertThat("_hostname:9200", equalTo(hit.field("origin_address").getValue())); + assertThat("_uri", equalTo(hit.field("uri").getValue())); + } + + @Test(expected = IndexMissingException.class) + public void testAnonymousAccessDenied_Rest_Muted() throws Exception { + initialize("anonymous_access_denied"); + RestRequest request = mockRestRequest(); + auditor.anonymousAccessDenied(request); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAuthenticationFailed_Transport() throws Exception { + + initialize(); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.authenticationFailed(new MockToken(), "_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "transport", "authentication_failed"); + + if (message instanceof RemoteHostMockMessage) { + assertEquals("remote_host:1234", hit.field("origin_address").getValue()); + } else { + assertEquals("local[local_host]", hit.field("origin_address").getValue()); + } + + assertEquals("_principal", hit.field("principal").getValue()); + assertEquals("_action", hit.field("action").getValue()); + assertEquals("transport", hit.field("origin_type").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAuthenticationFailed_Transport_Muted() throws Exception { + initialize("authentication_failed"); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.authenticationFailed(new MockToken(), "_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAuthenticationFailed_Rest() throws Exception { + + initialize(); + RestRequest request = mockRestRequest(); + auditor.authenticationFailed(new MockToken(), request); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "rest", "authentication_failed"); + assertThat("_hostname:9200", equalTo(hit.field("origin_address").getValue())); + assertThat("_uri", equalTo(hit.field("uri").getValue())); + } + + @Test(expected = IndexMissingException.class) + public void testAuthenticationFailed_Rest_Muted() throws Exception { + initialize("authentication_failed"); + RestRequest request = mockRestRequest(); + auditor.authenticationFailed(new MockToken(), request); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAuthenticationFailed_Transport_Realm() throws Exception { + + initialize(); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.authenticationFailed("_realm", new MockToken(), "_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "transport", "authentication_failed"); + + if (message instanceof RemoteHostMockMessage) { + assertEquals("remote_host:1234", hit.field("origin_address").getValue()); + } else { + assertEquals("local[local_host]", hit.field("origin_address").getValue()); + } + + assertEquals("transport", hit.field("origin_type").getValue()); + assertEquals("_principal", hit.field("principal").getValue()); + assertEquals("_action", hit.field("action").getValue()); + assertEquals("_realm", hit.field("realm").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAuthenticationFailed_Transport_Realm_Muted() throws Exception { + initialize("authentication_failed"); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.authenticationFailed("_realm", new MockToken(), "_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAuthenticationFailed_Rest_Realm() throws Exception { + + initialize(); + RestRequest request = mockRestRequest(); + auditor.authenticationFailed("_realm", new MockToken(), request); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "rest", "authentication_failed"); + assertThat("_hostname:9200", equalTo(hit.field("origin_address").getValue())); + assertThat("_uri", equalTo(hit.field("uri").getValue())); + assertEquals("_realm", hit.field("realm").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAuthenticationFailed_Rest_Realm_Muted() throws Exception { + initialize("authentication_failed"); + RestRequest request = mockRestRequest(); + auditor.authenticationFailed("_realm", new MockToken(), request); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAccessGranted() throws Exception { + + initialize(); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.accessGranted(new User.Simple("_username", "r1"), "_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + assertAuditMessage(hit, "transport", "access_granted"); + assertEquals("transport", hit.field("origin_type").getValue()); + assertEquals("_username", hit.field("principal").getValue()); + assertEquals("_action", hit.field("action").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAccessGranted_Muted() throws Exception { + initialize("access_granted"); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.accessGranted(new User.Simple("_username", "r1"), "_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testAccessDenied() throws Exception { + + initialize(); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.accessDenied(new User.Simple("_username", "r1"), "_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + assertAuditMessage(hit, "transport", "access_denied"); + assertEquals("transport", hit.field("origin_type").getValue()); + assertEquals("_username", hit.field("principal").getValue()); + assertEquals("_action", hit.field("action").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testAccessDenied_Muted() throws Exception { + initialize("access_denied"); + TransportMessage message = randomBoolean() ? new RemoteHostMockMessage() : new LocalHostMockMessage(); + auditor.accessDenied(new User.Simple("_username", "r1"), "_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testTamperedRequest() throws Exception { + + initialize(); + TransportRequest message = new RemoteHostMockTransportRequest(); + auditor.tamperedRequest(new User.Simple("_username", "r1"), "_action", message); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "transport", "tampered_request"); + assertEquals("transport", hit.field("origin_type").getValue()); + assertEquals("_username", hit.field("principal").getValue()); + assertEquals("_action", hit.field("action").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testTamperedRequest_Muted() throws Exception { + initialize("tampered_request"); + TransportRequest message = new RemoteHostMockTransportRequest(); + auditor.tamperedRequest(new User.Simple("_username", "r1"), "_action", message); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testConnectionGranted() throws Exception { + + initialize(); + InetAddress inetAddress = InetAddress.getLocalHost(); + ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; + auditor.connectionGranted(inetAddress, "default", rule); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "ip_filter", "connection_granted"); + assertEquals("allow default:accept_all", hit.field("rule").getValue()); + assertEquals("default", hit.field("transport_profile").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testConnectionGranted_Muted() throws Exception { + initialize("connection_granted"); + InetAddress inetAddress = InetAddress.getLocalHost(); + ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; + auditor.connectionGranted(inetAddress, "default", rule); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + @Test + public void testConnectionDenied() throws Exception { + + initialize(); + InetAddress inetAddress = InetAddress.getLocalHost(); + ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); + auditor.connectionDenied(inetAddress, "default", rule); + awaitIndexCreation(resolveIndexName()); + + SearchHit hit = getIndexedAuditMessage(); + + assertAuditMessage(hit, "ip_filter", "connection_denied"); + assertEquals("deny _all", hit.field("rule").getValue()); + assertEquals("default", hit.field("transport_profile").getValue()); + } + + @Test(expected = IndexMissingException.class) + public void testConnectionDenied_Muted() throws Exception { + initialize("connection_denied"); + InetAddress inetAddress = InetAddress.getLocalHost(); + ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); + auditor.connectionDenied(inetAddress, "default", rule); + getClient().prepareExists(resolveIndexName()).execute().actionGet(); + } + + private void assertAuditMessage(SearchHit hit, String layer, String type) { + + assertThat((Long) hit.field("timestamp").getValue(), greaterThan(0L)); + assertThat((Long) hit.field("timestamp").getValue(), lessThan(System.currentTimeMillis())); + + assertThat(clusterService().localNode().getHostName(), equalTo(hit.field("node_host_name").getValue())); + assertThat(clusterService().localNode().getHostAddress(), equalTo(hit.field("node_host_address").getValue())); + + assertEquals(layer, hit.field("layer").getValue()); + assertEquals(type, hit.field("type").getValue()); + } + + private static class LocalHostMockMessage extends TransportMessage { + LocalHostMockMessage() { + remoteAddress(new LocalTransportAddress("local_host")); + } + } + + private static class RemoteHostMockMessage extends TransportMessage { + RemoteHostMockMessage() { + remoteAddress(new InetSocketTransportAddress("remote_host", 1234)); + } + } + + private static class RemoteHostMockTransportRequest extends TransportRequest { + RemoteHostMockTransportRequest() { + remoteAddress(new InetSocketTransportAddress("remote_host", 1234)); + } + } + + private static class MockToken implements AuthenticationToken { + @Override + public String principal() { + return "_principal"; + } + + @Override + public Object credentials() { + fail("it's not allowed to print the credentials of the auth token"); + return null; + } + + @Override + public void clearCredentials() { + } + } + + private RestRequest mockRestRequest() { + RestRequest request = mock(RestRequest.class); + when(request.getRemoteAddress()).thenReturn(new InetSocketAddress("_hostname", 9200)); + when(request.uri()).thenReturn("_uri"); + return request; + } + + private SearchHit getIndexedAuditMessage() { + + SearchResponse response = getClient().prepareSearch(resolveIndexName()) + .setTypes(IndexAuditTrailBulkProcessor.DOC_TYPE) + .addFields(fieldList()) + .execute().actionGet(); + + assertEquals(1, response.getHits().getTotalHits()); + return response.getHits().getHits()[0]; + } + + private String[] fieldList() { + java.lang.reflect.Field[] fields = IndexAuditTrail.Field.class.getDeclaredFields(); + String[] array = new String[fields.length]; + for (int i = 0; i < fields.length; i++) { + array[i] = fields[i].getName().toLowerCase(Locale.ROOT); + } + return array; + } + + private void awaitIndexCreation(final String indexName) throws InterruptedException { + awaitBusy(new Predicate() { + @Override + public boolean apply(Void o) { + try { + ExistsResponse response = + getClient().prepareExists(indexName).execute().actionGet(); + return response.exists(); + } catch (Exception e) { + return false; + } + } + }); + } + + private String resolveIndexName() { + return resolver.resolve(IndexAuditTrailBulkProcessor.INDEX_NAME_PREFIX, System.currentTimeMillis(), rollover); + } +} +