diff --git a/shield/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java b/shield/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java index db86c8d17c6..88e8b169787 100644 --- a/shield/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java +++ b/shield/src/main/java/org/elasticsearch/shield/audit/index/IndexAuditTrail.java @@ -8,6 +8,7 @@ package org.elasticsearch.shield.audit.index; import com.google.common.base.Splitter; import com.google.common.collect.ImmutableSet; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.admin.indices.template.put.PutIndexTemplateRequest; import org.elasticsearch.action.admin.indices.template.put.PutIndexTemplateResponse; import org.elasticsearch.action.bulk.BulkProcessor; @@ -103,10 +104,10 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { private final Environment environment; private final LinkedBlockingQueue eventQueue; private final QueueConsumer queueConsumer; + private final boolean indexToRemoteCluster; private BulkProcessor bulkProcessor; private Client client; - private boolean indexToRemoteCluster; private IndexNameResolver.Rollover rollover; private String nodeHostName; private String nodeHostAddress; @@ -155,6 +156,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { logger.warn("invalid event type specified, using default for audit index output. include events [{}], exclude events [{}]", e, includedEvents, excludedEvents); events = parse(DEFAULT_EVENT_INCLUDES, Strings.EMPTY_ARRAY); } + this.indexToRemoteCluster = settings.getByPrefix("shield.audit.index.client.").names().size() > 0; + } public State state() { @@ -163,7 +166,8 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { /** * This method determines if this service can be started based on the state in the {@link ClusterChangedEvent} and - * if the node is the master or not. In order for the service to start, the following must be true: + * if the node is the master or not. When using remote indexing, a call to the remote cluster will be made to retrieve + * the state and the same rules will be applied. In order for the service to start, the following must be true: * *
    *
  1. The cluster must not have a {@link GatewayService#STATE_NOT_RECOVERED_BLOCK}; in other words the gateway @@ -177,15 +181,32 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { * @param master flag indicating if the current node is the master * @return true if all requirements are met and the service can be started */ - public boolean canStart(ClusterChangedEvent event, boolean master) { - if (event.state().blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { + public synchronized boolean canStart(ClusterChangedEvent event, boolean master) { + if (indexToRemoteCluster) { + try { + if (client == null) { + initializeClient(); + } + } catch (Exception e) { + logger.error("failed to initialize client for remote indexing. index based output is disabled", e); + state.set(State.FAILED); + return false; + } + + ClusterStateResponse response = client.admin().cluster().prepareState().execute().actionGet(); + return canStart(response.getState(), master); + } + return canStart(event.state(), master); + } + + private boolean canStart(ClusterState clusterState, boolean master) { + if (clusterState.blocks().hasGlobalBlock(GatewayService.STATE_NOT_RECOVERED_BLOCK)) { // wait until the gateway has recovered from disk, otherwise we think may not have .shield-audit- // but they may not have been restored from the cluster state on disk logger.debug("index audit trail waiting until gateway has recovered from disk"); return false; } - final ClusterState clusterState = event.state(); if (!master && clusterState.metaData().templates().get(INDEX_TEMPLATE_NAME) == null) { logger.debug("shield audit index template [{}] does not exist, so service cannot start", INDEX_TEMPLATE_NAME); return false; @@ -227,7 +248,10 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { this.nodeHostName = hostname; this.nodeHostAddress = hostaddr; - initializeClient(); + if (client == null) { + initializeClient(); + } + if (master) { putTemplate(customAuditIndexSettings(settings)); } @@ -272,7 +296,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void anonymousAccessDenied(String action, TransportMessage message) { if (events.contains(ANONYMOUS_ACCESS_DENIED)) { try { - enqueue(message("anonymous_access_denied", action, null, null, indices(message), message)); + enqueue(message("anonymous_access_denied", action, null, null, indices(message), message), "anonymous_access_denied"); } catch (Exception e) { logger.warn("failed to index audit event: [anonymous_access_denied]", e); } @@ -283,7 +307,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void anonymousAccessDenied(RestRequest request) { if (events.contains(ANONYMOUS_ACCESS_DENIED)) { try { - enqueue(message("anonymous_access_denied", null, null, null, null, request)); + enqueue(message("anonymous_access_denied", null, null, null, null, request), "anonymous_access_denied"); } catch (Exception e) { logger.warn("failed to index audit event: [anonymous_access_denied]", e); } @@ -294,7 +318,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void authenticationFailed(String action, TransportMessage message) { if (events.contains(AUTHENTICATION_FAILED)) { try { - enqueue(message("authentication_failed", action, null, null, indices(message), message)); + enqueue(message("authentication_failed", action, null, null, indices(message), message), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -305,7 +329,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void authenticationFailed(RestRequest request) { if (events.contains(AUTHENTICATION_FAILED)) { try { - enqueue(message("authentication_failed", null, null, null, null, request)); + enqueue(message("authentication_failed", null, null, null, null, request), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -317,7 +341,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(AUTHENTICATION_FAILED)) { if (!principalIsAuditor(token.principal())) { try { - enqueue(message("authentication_failed", action, token.principal(), null, indices(message), message)); + enqueue(message("authentication_failed", action, token.principal(), null, indices(message), message), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -330,7 +354,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(AUTHENTICATION_FAILED)) { if (!principalIsAuditor(token.principal())) { try { - enqueue(message("authentication_failed", null, token.principal(), null, null, request)); + enqueue(message("authentication_failed", null, token.principal(), null, null, request), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -343,7 +367,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(AUTHENTICATION_FAILED)) { if (!principalIsAuditor(token.principal())) { try { - enqueue(message("authentication_failed", action, token.principal(), realm, indices(message), message)); + enqueue(message("authentication_failed", action, token.principal(), realm, indices(message), message), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -356,7 +380,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(AUTHENTICATION_FAILED)) { if (!principalIsAuditor(token.principal())) { try { - enqueue(message("authentication_failed", null, token.principal(), realm, null, request)); + enqueue(message("authentication_failed", null, token.principal(), realm, null, request), "authentication_failed"); } catch (Exception e) { logger.warn("failed to index audit event: [authentication_failed]", e); } @@ -371,14 +395,14 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (user.isSystem() && Privilege.SYSTEM.predicate().apply(action)) { if (events.contains(SYSTEM_ACCESS_GRANTED)) { try { - enqueue(message("access_granted", action, user.principal(), null, indices(message), message)); + enqueue(message("access_granted", action, user.principal(), null, indices(message), message), "access_granted"); } catch (Exception e) { logger.warn("failed to index audit event: [access_granted]", e); } } } else if (events.contains(ACCESS_GRANTED)) { try { - enqueue(message("access_granted", action, user.principal(), null, indices(message), message)); + enqueue(message("access_granted", action, user.principal(), null, indices(message), message), "access_granted"); } catch (Exception e) { logger.warn("failed to index audit event: [access_granted]", e); } @@ -391,7 +415,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(ACCESS_DENIED)) { if (!principalIsAuditor(user.principal())) { try { - enqueue(message("access_denied", action, user.principal(), null, indices(message), message)); + enqueue(message("access_denied", action, user.principal(), null, indices(message), message), "access_denied"); } catch (Exception e) { logger.warn("failed to index audit event: [access_denied]", e); } @@ -404,7 +428,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { if (events.contains(TAMPERED_REQUEST)) { if (!principalIsAuditor(user.principal())) { try { - enqueue(message("tampered_request", action, user.principal(), null, indices(request), request)); + enqueue(message("tampered_request", action, user.principal(), null, indices(request), request), "tampered_request"); } catch (Exception e) { logger.warn("failed to index audit event: [tampered_request]", e); } @@ -416,7 +440,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void connectionGranted(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) { if (events.contains(CONNECTION_GRANTED)) { try { - enqueue(message("ip_filter", "connection_granted", inetAddress, profile, rule)); + enqueue(message("ip_filter", "connection_granted", inetAddress, profile, rule), "connection_granted"); } catch (Exception e) { logger.warn("failed to index audit event: [connection_granted]", e); } @@ -427,7 +451,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { public void connectionDenied(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) { if (events.contains(CONNECTION_DENIED)) { try { - enqueue(message("ip_filter", "connection_denied", inetAddress, profile, rule)); + enqueue(message("ip_filter", "connection_denied", inetAddress, profile, rule), "connection_denied"); } catch (Exception e) { logger.warn("failed to index audit event: [connection_denied]", e); } @@ -545,25 +569,22 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { return builder; } - void enqueue(Message message) { + void enqueue(Message message, String type) { State currentState = state(); if (currentState != State.STOPPING && currentState != State.STOPPED) { boolean accepted = eventQueue.offer(message); if (!accepted) { - throw new IllegalStateException("queue is full, bulk processor may have stopped indexing"); + logger.warn("failed to index audit event: [{}]. queue is full; bulk processor may not be able to keep up or has stopped indexing.", type); } } } private void initializeClient() { - - Settings clientSettings = settings.getByPrefix("shield.audit.index.client."); - - if (clientSettings.names().size() == 0) { + if (indexToRemoteCluster == false) { // in the absence of client settings for remote indexing, fall back to the client that was passed in. this.client = clientProvider.get(); - indexToRemoteCluster = false; } else { + Settings clientSettings = settings.getByPrefix("shield.audit.index.client."); String[] hosts = clientSettings.getAsArray("hosts"); if (hosts.length == 0) { throw new ElasticsearchException("missing required setting " + @@ -591,7 +612,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { final TransportClient transportClient = TransportClient.builder() .settings(Settings.builder() - .put("name", DEFAULT_CLIENT_NAME) + .put("name", DEFAULT_CLIENT_NAME + "-" + settings.get("name")) .put("path.home", environment.homeFile()) .putArray("plugin.types", ShieldPlugin.class.getName()) .put(clientSettings)) @@ -601,8 +622,6 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { } this.client = transportClient; - indexToRemoteCluster = true; - logger.info("forwarding audit events to remote cluster [{}] using hosts [{}]", clientSettings.get("cluster.name", ""), hostPortPairs.toString()); } @@ -773,6 +792,7 @@ public class IndexAuditTrail extends AbstractComponent implements AuditTrail { STARTING, STARTED, STOPPING, - STOPPED + STOPPED, + FAILED } } diff --git a/shield/src/test/java/org/elasticsearch/shield/audit/index/RemoteIndexAuditTrailStartingTests.java b/shield/src/test/java/org/elasticsearch/shield/audit/index/RemoteIndexAuditTrailStartingTests.java new file mode 100644 index 00000000000..ef490606912 --- /dev/null +++ b/shield/src/test/java/org/elasticsearch/shield/audit/index/RemoteIndexAuditTrailStartingTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.shield.audit.index; + +import com.google.common.base.Predicate; +import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; +import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.transport.InetSocketTransportAddress; +import org.elasticsearch.test.InternalTestCluster; +import org.elasticsearch.test.ShieldIntegrationTest; +import org.elasticsearch.test.ShieldSettingsSource; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.test.InternalTestCluster.clusterName; +import static org.hamcrest.Matchers.is; + +/** + * This test checks to ensure that the IndexAuditTrail starts properly when indexing to a remote cluster + */ +public class RemoteIndexAuditTrailStartingTests extends ShieldIntegrationTest { + + public static final String SECOND_CLUSTER_NODE_PREFIX = "remote_" + SUITE_CLUSTER_NODE_PREFIX; + + private InternalTestCluster remoteCluster; + + private final boolean useSSL = randomBoolean(); + private final boolean localAudit = randomBoolean(); + private final String outputs = randomFrom("index", "logfile", "index,logfile"); + + @Override + public boolean sslTransportEnabled() { + return useSSL; + } + + @Override + public Settings nodeSettings(int nodeOrdinal) { + return Settings.builder() + .put(super.nodeSettings(nodeOrdinal)) + .put("shield.audit.enabled", localAudit) + .put("shield.audit.outputs", outputs) + .build(); + } + + @Before + public void startRemoteCluster() throws IOException { + final List addresses = new ArrayList<>(); + // get addresses for current cluster + NodesInfoResponse response = client().admin().cluster().prepareNodesInfo().execute().actionGet(); + final String clusterName = response.getClusterNameAsString(); + for (NodeInfo nodeInfo : response.getNodes()) { + InetSocketTransportAddress address = (InetSocketTransportAddress) nodeInfo.getTransport().address().publishAddress(); + addresses.add(address.address().getHostString() + ":" + address.address().getPort()); + } + + // create another cluster + String cluster2Name = clusterName(Scope.SUITE.name(), randomLong()); + + // Setup a second test cluster with randomization for number of nodes, shield enabled, and SSL + final int numNodes = randomIntBetween(2, 3); + ShieldSettingsSource cluster2SettingsSource = new ShieldSettingsSource(numNodes, useSSL, systemKey(), createTempDir(), Scope.SUITE) { + @Override + public Settings node(int nodeOrdinal) { + Settings.Builder builder = Settings.builder() + .put(super.node(nodeOrdinal)) + .put("shield.audit.enabled", true) + .put("shield.audit.outputs", randomFrom("index", "index,logfile")) + .putArray("shield.audit.index.client.hosts", addresses.toArray(new String[addresses.size()])) + .put("shield.audit.index.client.cluster.name", clusterName) + .put("shield.audit.index.client.shield.user", ShieldSettingsSource.DEFAULT_USER_NAME + ":" + ShieldSettingsSource.DEFAULT_PASSWORD); + + if (useSSL) { + for (Map.Entry entry : getClientSSLSettings().getAsMap().entrySet()) { + builder.put("shield.audit.index.client." + entry.getKey(), entry.getValue()); + } + } + return builder.build(); + } + }; + remoteCluster = new InternalTestCluster(randomLong(), createTempDir(), numNodes, numNodes, cluster2Name, cluster2SettingsSource, 0, false, SECOND_CLUSTER_NODE_PREFIX); + remoteCluster.beforeTest(getRandom(), 0.5); + } + + @After + public void stopRemoteCluster() throws Exception { + if (remoteCluster != null) { + try { + remoteCluster.wipe(); + } finally { + remoteCluster.afterTest(); + } + remoteCluster.close(); + } + } + + @Test + public void testThatRemoteAuditInstancesAreStarted() throws Exception { + Iterable auditTrails = remoteCluster.getInstances(IndexAuditTrail.class); + for (final IndexAuditTrail auditTrail : auditTrails) { + awaitBusy(new Predicate() { + @Override + public boolean apply(Void aVoid) { + return auditTrail.state() == IndexAuditTrail.State.STARTED; + } + }, 2L, TimeUnit.SECONDS); + assertThat(auditTrail.state(), is(IndexAuditTrail.State.STARTED)); + } + } + +}