From fc321a02a128e9871807deed07085f1dea667604 Mon Sep 17 00:00:00 2001 From: jaymode Date: Wed, 9 Dec 2015 07:11:17 -0500 Subject: [PATCH] fix logging audit trail to not cause guice issues When the logging audit trail is configured to add the node hostname or ip address as a prefix, the logging audit trail can invoke guice dependency injection issues since the transport that is injected is a proxy. This change makes the logging audit trail a lifecycle component and waits for the transport to be started before initializing the prefix. Closes elastic/elasticsearch#1104 Original commit: elastic/x-pack-elasticsearch@3b192839692c5ef380fe682a5f8252294600a8eb --- .../elasticsearch/shield/ShieldPlugin.java | 14 +++++-- .../shield/audit/AuditTrailModule.java | 6 +-- .../audit/logfile/LoggingAuditTrail.java | 38 ++++++++++++++++--- .../audit/logfile/LoggingAuditTrailTests.java | 38 ++++++++++--------- .../test/ShieldSettingsSource.java | 3 ++ 5 files changed, 70 insertions(+), 29 deletions(-) diff --git a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/ShieldPlugin.java b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/ShieldPlugin.java index a4fb96e0035..aac2975ea7f 100644 --- a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/ShieldPlugin.java +++ b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/ShieldPlugin.java @@ -24,6 +24,7 @@ import org.elasticsearch.shield.action.authc.cache.ClearRealmCacheAction; import org.elasticsearch.shield.action.authc.cache.TransportClearRealmCacheAction; import org.elasticsearch.shield.audit.AuditTrailModule; import org.elasticsearch.shield.audit.index.IndexAuditUserHolder; +import org.elasticsearch.shield.audit.logfile.LoggingAuditTrail; import org.elasticsearch.shield.authc.AuthenticationModule; import org.elasticsearch.shield.authc.Realms; import org.elasticsearch.shield.authc.support.SecuredString; @@ -52,8 +53,6 @@ import org.elasticsearch.xpack.XPackPlugin; import java.nio.file.Path; import java.util.*; -import java.security.AccessController; -import java.security.PrivilegedAction; /** * @@ -121,7 +120,16 @@ public class ShieldPlugin extends Plugin { @Override public Collection> nodeServices() { if (enabled && clientMode == false) { - return Arrays.>asList(ShieldLicensee.class, InternalCryptoService.class, FileRolesStore.class, Realms.class, IPFilter.class); + List> list = new ArrayList<>(); + if (AuditTrailModule.fileAuditLoggingEnabled(settings)) { + list.add(LoggingAuditTrail.class); + } + list.add(ShieldLicensee.class); + list.add(InternalCryptoService.class); + list.add(FileRolesStore.class); + list.add(Realms.class); + list.add(IPFilter.class); + return list; } return Collections.emptyList(); } diff --git a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java index cedfb0fec4d..ed86437ab60 100644 --- a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java +++ b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/AuditTrailModule.java @@ -70,11 +70,11 @@ public class AuditTrailModule extends AbstractShieldModule.Node { return settings.getAsBoolean("shield.audit.enabled", false); } - public static boolean indexAuditLoggingEnabled(Settings settings) { + public static boolean fileAuditLoggingEnabled(Settings settings) { if (auditingEnabled(settings)) { - String[] outputs = settings.getAsArray("shield.audit.outputs"); + String[] outputs = settings.getAsArray("shield.audit.outputs", new String[] { LoggingAuditTrail.NAME }); for (String output : outputs) { - if (output.equals(IndexAuditTrail.NAME)) { + if (output.equals(LoggingAuditTrail.NAME)) { return true; } } diff --git a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java index 8ac867f1380..19a9bfa16af 100644 --- a/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java +++ b/elasticsearch/x-pack/shield/src/main/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrail.java @@ -5,6 +5,9 @@ */ package org.elasticsearch.shield.audit.logfile; +import org.elasticsearch.common.component.AbstractLifecycleComponent; +import org.elasticsearch.common.component.Lifecycle; +import org.elasticsearch.common.component.LifecycleListener; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.logging.ESLogger; import org.elasticsearch.common.logging.Loggers; @@ -33,14 +36,15 @@ import static org.elasticsearch.shield.audit.AuditUtil.restRequestContent; /** * */ -public class LoggingAuditTrail implements AuditTrail { +public class LoggingAuditTrail extends AbstractLifecycleComponent implements AuditTrail { public static final String NAME = "logfile"; - private final String prefix; private final ESLogger logger; private final Transport transport; + private String prefix; + @Override public String name() { return NAME; @@ -48,19 +52,43 @@ public class LoggingAuditTrail implements AuditTrail { @Inject public LoggingAuditTrail(Settings settings, Transport transport) { - this(resolvePrefix(settings, transport), transport, Loggers.getLogger(LoggingAuditTrail.class)); + this(settings, transport, Loggers.getLogger(LoggingAuditTrail.class)); } LoggingAuditTrail(Settings settings, Transport transport, ESLogger logger) { - this(resolvePrefix(settings, transport), transport, logger); + this("", settings, transport, logger); } - LoggingAuditTrail(String prefix, Transport transport, ESLogger logger) { + LoggingAuditTrail(String prefix, Settings settings, Transport transport, ESLogger logger) { + super(settings); this.logger = logger; this.prefix = prefix; this.transport = transport; } + + @Override + protected void doStart() { + if (transport.lifecycleState() == Lifecycle.State.STARTED) { + prefix = resolvePrefix(settings, transport); + } else { + transport.addLifecycleListener(new LifecycleListener() { + @Override + public void afterStart() { + prefix = resolvePrefix(settings, transport); + } + }); + } + } + + @Override + protected void doStop() { + } + + @Override + protected void doClose() { + } + @Override public void anonymousAccessDenied(String action, TransportMessage message) { String indices = indicesString(message); diff --git a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrailTests.java b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrailTests.java index 8ea3f1af7dc..6915a18b2c2 100644 --- a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrailTests.java +++ b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/shield/audit/logfile/LoggingAuditTrailTests.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.IndicesRequest; import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.network.NetworkAddress; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.BoundTransportAddress; @@ -110,6 +111,7 @@ public class LoggingAuditTrailTests extends ESTestCase { .put("shield.audit.logfile.prefix.emit_node_name", randomBoolean()) .build(); transport = mock(Transport.class); + when(transport.lifecycleState()).thenReturn(Lifecycle.State.STARTED); when(transport.boundAddress()).thenReturn(new BoundTransportAddress(new TransportAddress[] { DummyTransportAddress.INSTANCE }, DummyTransportAddress.INSTANCE)); prefix = LoggingAuditTrail.resolvePrefix(settings, transport); } @@ -117,7 +119,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAnonymousAccessDeniedTransport() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport); auditTrail.anonymousAccessDenied("_action", message); @@ -153,7 +155,7 @@ public class LoggingAuditTrailTests extends ESTestCase { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.anonymousAccessDenied(request); switch (level) { case ERROR: @@ -173,7 +175,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAuthenticationFailed() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport);; auditTrail.authenticationFailed(new MockToken(), "_action", message); @@ -201,7 +203,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAuthenticationFailedNoToken() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport);; auditTrail.authenticationFailed("_action", message); @@ -234,7 +236,7 @@ public class LoggingAuditTrailTests extends ESTestCase { when(request.uri()).thenReturn("_uri"); String expectedMessage = prepareRestContent(request); CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.authenticationFailed(new MockToken(), request); switch (level) { case ERROR: @@ -257,7 +259,7 @@ public class LoggingAuditTrailTests extends ESTestCase { when(request.uri()).thenReturn("_uri"); String expectedMessage = prepareRestContent(request); CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.authenticationFailed(request); switch (level) { case ERROR: @@ -275,7 +277,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAuthenticationFailedRealm() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport);; auditTrail.authenticationFailed("_realm", new MockToken(), "_action", message); @@ -304,7 +306,7 @@ public class LoggingAuditTrailTests extends ESTestCase { when(request.uri()).thenReturn("_uri"); String expectedMessage = prepareRestContent(request); CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.authenticationFailed("_realm", new MockToken(), request); switch (level) { case ERROR: @@ -322,7 +324,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAccessGranted() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport); boolean runAs = randomBoolean(); @@ -360,7 +362,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAccessGrantedInternalSystemAction() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport); auditTrail.accessGranted(User.SYSTEM, "internal:_action", message); @@ -384,7 +386,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAccessGrantedInternalSystemActionNonSystemUser() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport); boolean runAs = randomBoolean(); @@ -422,7 +424,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testAccessDenied() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = randomBoolean() ? new MockMessage() : new MockIndicesRequest(); String origins = LoggingAuditTrail.originAttributes(message, transport); boolean runAs = randomBoolean(); @@ -461,7 +463,7 @@ public class LoggingAuditTrailTests extends ESTestCase { String origins = LoggingAuditTrail.originAttributes(message, transport); for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.tamperedRequest(action, message); switch (level) { case ERROR: @@ -498,7 +500,7 @@ public class LoggingAuditTrailTests extends ESTestCase { String userInfo = runAs ? "principal=[running as], run_by_principal=[_username]" : "principal=[_username]"; for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); auditTrail.tamperedRequest(user, action, message); switch (level) { case ERROR: @@ -524,7 +526,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testConnectionDenied() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); InetAddress inetAddress = InetAddress.getLoopbackAddress(); ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all"); auditTrail.connectionDenied(inetAddress, "default", rule); @@ -544,7 +546,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testConnectionGranted() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); InetAddress inetAddress = InetAddress.getLoopbackAddress(); ShieldIpFilterRule rule = IPFilter.DEFAULT_PROFILE_ACCEPT_ALL; auditTrail.connectionGranted(inetAddress, "default", rule); @@ -566,7 +568,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testRunAsGranted() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = new MockMessage(); String origins = LoggingAuditTrail.originAttributes(message, transport); User user = new User.Simple("_username", new String[]{"r1"}, new User.Simple("running as", new String[] {"r2"})); @@ -589,7 +591,7 @@ public class LoggingAuditTrailTests extends ESTestCase { public void testRunAsDenied() throws Exception { for (Level level : Level.values()) { CapturingLogger logger = new CapturingLogger(level); - LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger); + LoggingAuditTrail auditTrail = new LoggingAuditTrail(settings, transport, logger).start(); TransportMessage message = new MockMessage(); String origins = LoggingAuditTrail.originAttributes(message, transport); User user = new User.Simple("_username", new String[]{"r1"}, new User.Simple("running as", new String[] {"r2"})); diff --git a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/test/ShieldSettingsSource.java b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/test/ShieldSettingsSource.java index 658fa3e5b15..a1ca3b46b71 100644 --- a/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/test/ShieldSettingsSource.java +++ b/elasticsearch/x-pack/shield/src/test/java/org/elasticsearch/test/ShieldSettingsSource.java @@ -122,6 +122,9 @@ public class ShieldSettingsSource extends ClusterDiscoveryConfiguration.UnicastZ .put("marvel.enabled", false) .put("shield.audit.enabled", randomBoolean()) + .put("shield.audit.logfile.prefix.emit_node_host_address", randomBoolean()) + .put("shield.audit.logfile.prefix.emit_node_host_name", randomBoolean()) + .put("shield.audit.logfile.prefix.emit_node_name", randomBoolean()) .put(InternalCryptoService.FILE_SETTING, writeFile(folder, "system_key", systemKey)) .put("shield.authc.realms.esusers.type", ESUsersRealm.TYPE) .put("shield.authc.realms.esusers.order", 0)