security: add the proper behavior for the standard license

This change adds the proper behavior for the standard license which is:

* authentication is enabled but only the reserved, native and file realms are available
* authorization is enabled

Features that are disabled:

* auditing
* ip filtering
* custom realms
* LDAP, Active Directory, PKI realms

See elastic/elasticsearch#1263

Original commit: elastic/x-pack-elasticsearch@920c045bf1
This commit is contained in:
jaymode 2016-04-14 15:25:49 -04:00
parent 077599b63f
commit c39b3ba2fc
20 changed files with 415 additions and 133 deletions

View File

@ -36,7 +36,7 @@ public class SecurityFeatureSet implements XPackFeatureSet {
@Override
public boolean available() {
return licenseState != null && licenseState.securityEnabled();
return licenseState != null && licenseState.authenticationAndAuthorizationEnabled();
}
@Override

View File

@ -21,10 +21,29 @@ public class SecurityLicenseState {
protected volatile Status status = Status.ENABLED;
/**
* @return true if the license allows for security features to be enabled (authc, authz, ip filter, audit, etc)
* @return true if authentication and authorization should be enabled. this does not indicate what realms are available
* @see SecurityLicenseState#enabledRealmType() for the enabled realms
*/
public boolean securityEnabled() {
return status.getMode() != OperationMode.BASIC;
public boolean authenticationAndAuthorizationEnabled() {
OperationMode mode = status.getMode();
return mode == OperationMode.STANDARD || mode == OperationMode.GOLD || mode == OperationMode.PLATINUM
|| mode == OperationMode.TRIAL;
}
/**
* @return true if IP filtering should be enabled
*/
public boolean ipFilteringEnabled() {
OperationMode mode = status.getMode();
return mode == OperationMode.GOLD || mode == OperationMode.PLATINUM || mode == OperationMode.TRIAL;
}
/**
* @return true if auditing should be enabled
*/
public boolean auditingEnabled() {
OperationMode mode = status.getMode();
return mode == OperationMode.GOLD || mode == OperationMode.PLATINUM || mode == OperationMode.TRIAL;
}
/**
@ -55,23 +74,31 @@ public class SecurityLicenseState {
}
/**
* Determine if Custom Realms should be enabled.
* <p>
* Custom Realms are only disabled when the mode is not:
* <ul>
* <li>{@link OperationMode#PLATINUM}</li>
* <li>{@link OperationMode#TRIAL}</li>
* </ul>
* Note: This does not consider the <em>state</em> of the license so that Security does not suddenly block requests!
*
* @return {@code true} to enable Custom Realms. Otherwise {@code false}.
* @return the type of realms that are enabled based on the license {@link OperationMode}
*/
public boolean customRealmsEnabled() {
Status status = this.status;
return status.getMode() == OperationMode.TRIAL || status.getMode() == OperationMode.PLATINUM;
public EnabledRealmType enabledRealmType() {
OperationMode mode = status.getMode();
switch (mode) {
case PLATINUM:
case TRIAL:
return EnabledRealmType.ALL;
case GOLD:
return EnabledRealmType.DEFAULT;
case STANDARD:
return EnabledRealmType.NATIVE;
default:
return EnabledRealmType.NONE;
}
}
void updateStatus(Status status) {
this.status = status;
}
public enum EnabledRealmType {
NONE,
NATIVE,
DEFAULT,
ALL
}
}

View File

@ -54,10 +54,10 @@ public class SecurityLicensee extends AbstractLicenseeComponent<SecurityLicensee
case GOLD:
case PLATINUM:
return new String[] {
"The following Shield functionality will be disabled: authentication, authorization, ip filtering, " +
"auditing, SSL will be disabled on node restart. Please restart your node after applying the license.",
"Field and document level access control will be disabled",
"Custom realms will be ignored"
"The following X-Pack security functionality will be disabled: authentication, authorization, " +
"ip filtering, and auditing. Please restart your node after applying the license.",
"Field and document level access control will be disabled.",
"Custom realms will be ignored."
};
}
}
@ -71,12 +71,28 @@ public class SecurityLicensee extends AbstractLicenseeComponent<SecurityLicensee
case TRIAL:
case PLATINUM:
return new String[] {
"Field and document level access control will be disabled",
"Custom realms will be ignored"
"Field and document level access control will be disabled.",
"Custom realms will be ignored."
};
}
}
break;
case STANDARD:
if (currentLicense != null) {
switch (currentLicense.operationMode()) {
case BASIC:
// ^^ though technically it was already disabled, it's not bad to remind them
case GOLD:
case PLATINUM:
case TRIAL:
return new String[] {
"Authentication will be limited to the native realms.",
"IP filtering and auditing will be disabled.",
"Field and document level access control will be disabled.",
"Custom realms will be ignored."
};
}
}
}
return Strings.EMPTY_ARRAY;
}

View File

@ -94,7 +94,7 @@ public class ShieldActionFilter extends AbstractComponent implements ActionFilte
final boolean restoreOriginalContext = threadContext.getHeader(InternalAuthenticationService.USER_KEY) != null ||
threadContext.getTransient(InternalAuthenticationService.USER_KEY) != null;
try {
if (licenseState.securityEnabled()) {
if (licenseState.authenticationAndAuthorizationEnabled()) {
if (AuthorizationUtils.shouldReplaceUserWithSystem(threadContext, action)) {
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
applyInternal(task, action, request, new SigningListener(this, listener, original), chain);

View File

@ -9,6 +9,7 @@ import org.elasticsearch.common.component.AbstractComponent;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.SecurityLicenseState;
import org.elasticsearch.shield.user.User;
import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.transport.filter.ShieldIpFilterRule;
@ -22,6 +23,7 @@ import java.util.Set;
*/
public class AuditTrailService extends AbstractComponent implements AuditTrail {
private final SecurityLicenseState securityLicenseState;
final AuditTrail[] auditTrails;
@Override
@ -30,78 +32,99 @@ public class AuditTrailService extends AbstractComponent implements AuditTrail {
}
@Inject
public AuditTrailService(Settings settings, Set<AuditTrail> auditTrails) {
public AuditTrailService(Settings settings, Set<AuditTrail> auditTrails, SecurityLicenseState licenseState) {
super(settings);
this.auditTrails = auditTrails.toArray(new AuditTrail[auditTrails.size()]);
this.securityLicenseState = licenseState;
}
@Override
public void anonymousAccessDenied(String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.anonymousAccessDenied(action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.anonymousAccessDenied(action, message);
}
}
}
@Override
public void anonymousAccessDenied(RestRequest request) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.anonymousAccessDenied(request);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.anonymousAccessDenied(request);
}
}
}
@Override
public void authenticationFailed(RestRequest request) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(request);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(request);
}
}
}
@Override
public void authenticationFailed(String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(action, message);
}
}
}
@Override
public void authenticationFailed(AuthenticationToken token, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(token, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(token, action, message);
}
}
}
@Override
public void authenticationFailed(String realm, AuthenticationToken token, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(realm, token, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(realm, token, action, message);
}
}
}
@Override
public void authenticationFailed(AuthenticationToken token, RestRequest request) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(token, request);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(token, request);
}
}
}
@Override
public void authenticationFailed(String realm, AuthenticationToken token, RestRequest request) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(realm, token, request);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.authenticationFailed(realm, token, request);
}
}
}
@Override
public void accessGranted(User user, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.accessGranted(user, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.accessGranted(user, action, message);
}
}
}
@Override
public void accessDenied(User user, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.accessDenied(user, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.accessDenied(user, action, message);
}
}
}
@ -114,43 +137,55 @@ public class AuditTrailService extends AbstractComponent implements AuditTrail {
@Override
public void tamperedRequest(String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.tamperedRequest(action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.tamperedRequest(action, message);
}
}
}
@Override
public void tamperedRequest(User user, String action, TransportMessage request) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.tamperedRequest(user, action, request);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.tamperedRequest(user, action, request);
}
}
}
@Override
public void connectionGranted(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.connectionGranted(inetAddress, profile, rule);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.connectionGranted(inetAddress, profile, rule);
}
}
}
@Override
public void connectionDenied(InetAddress inetAddress, String profile, ShieldIpFilterRule rule) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.connectionDenied(inetAddress, profile, rule);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.connectionDenied(inetAddress, profile, rule);
}
}
}
@Override
public void runAsGranted(User user, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.runAsGranted(user, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.runAsGranted(user, action, message);
}
}
}
@Override
public void runAsDenied(User user, String action, TransportMessage message) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.runAsDenied(user, action, message);
if (securityLicenseState.auditingEnabled()) {
for (AuditTrail auditTrail : auditTrails) {
auditTrail.runAsDenied(user, action, message);
}
}
}
}

View File

@ -13,12 +13,14 @@ import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.SettingsModule;
import org.elasticsearch.env.Environment;
import org.elasticsearch.shield.SecurityLicenseState.EnabledRealmType;
import org.elasticsearch.shield.authc.esnative.ReservedRealm;
import org.elasticsearch.shield.authc.esnative.NativeRealm;
import org.elasticsearch.shield.authc.file.FileRealm;
import org.elasticsearch.shield.SecurityLicenseState;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
@ -41,8 +43,10 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
private final ReservedRealm reservedRealm;
protected List<Realm> realms = Collections.emptyList();
// a list of realms that are "internal" in that they are provided by shield and not a third party
// a list of realms that are considered default in that they are provided by x-pack and not a third party
protected List<Realm> internalRealmsOnly = Collections.emptyList();
// a list of realms that are considered native, that is they only interact with x-pack and no 3rd party auth sources
protected List<Realm> nativeRealmsOnly = Collections.emptyList();
@Inject
public Realms(Settings settings, Environment env, Map<String, Realm.Factory> factories, SecurityLicenseState shieldLicenseState,
@ -61,22 +65,30 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
// pre-computing a list of internal only realms allows us to have much cheaper iteration than a custom iterator
// and is also simpler in terms of logic. These lists are small, so the duplication should not be a real issue here
List<Realm> internalRealms = new ArrayList<>();
List<Realm> nativeRealms = new ArrayList<>();
for (Realm realm : realms) {
if (AuthenticationModule.INTERNAL_REALM_TYPES.contains(realm.type())) {
// don't add the reserved realm here otherwise we end up with only this realm...
if (AuthenticationModule.INTERNAL_REALM_TYPES.contains(realm.type()) && ReservedRealm.TYPE.equals(realm.type()) == false) {
internalRealms.add(realm);
}
if (FileRealm.TYPE.equals(realm.type()) || NativeRealm.TYPE.equals(realm.type())) {
nativeRealms.add(realm);
}
}
if (internalRealms.isEmpty()) {
addInternalRealms(internalRealms);
}
for (List<Realm> realmList : Arrays.asList(internalRealms, nativeRealms)) {
if (realmList.isEmpty()) {
addNativeRealms(realmList);
}
if (internalRealms.contains(reservedRealm) == false) {
internalRealms.add(0, reservedRealm);
assert realmList.contains(reservedRealm) == false;
realmList.add(0, reservedRealm);
assert realmList.get(0) == reservedRealm;
}
assert internalRealms.get(0) == reservedRealm;
this.internalRealmsOnly = Collections.unmodifiableList(internalRealms);
this.nativeRealmsOnly = Collections.unmodifiableList(nativeRealms);
}
@Override
@ -89,10 +101,17 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
@Override
public Iterator<Realm> iterator() {
if (shieldLicenseState.customRealmsEnabled()) {
return realms.iterator();
EnabledRealmType enabledRealmType = shieldLicenseState.enabledRealmType();
switch (enabledRealmType) {
case ALL:
return realms.iterator();
case DEFAULT:
return internalRealmsOnly.iterator();
case NATIVE:
return nativeRealmsOnly.iterator();
default:
throw new IllegalStateException("authentication should not be enabled");
}
return internalRealmsOnly.iterator();
}
public Realm realm(String name) {
@ -145,7 +164,7 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
Collections.sort(realms);
} else {
// there is no "realms" configuration, add the defaults
addInternalRealms(realms);
addNativeRealms(realms);
}
// always add built in first!
realms.add(0, reservedRealm);
@ -177,7 +196,7 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
return result != null ? result : Settings.EMPTY;
}
private void addInternalRealms(List<Realm> realms) {
private void addNativeRealms(List<Realm> realms) {
Realm.Factory fileRealm = factories.get(FileRealm.TYPE);
if (fileRealm != null) {
realms.add(fileRealm.createDefault("default_" + FileRealm.TYPE));

View File

@ -58,7 +58,7 @@ public class ShieldRestFilter extends RestFilter {
@Override
public void process(RestRequest request, RestChannel channel, RestFilterChain filterChain) throws Exception {
if (licenseState.securityEnabled()) {
if (licenseState.authenticationAndAuthorizationEnabled()) {
// CORS - allow for preflight unauthenticated OPTIONS request
if (request.method() != RestRequest.Method.OPTIONS) {
if (extractClientCertificate) {

View File

@ -174,7 +174,7 @@ public class ShieldServerTransportService extends TransportService {
@Override
public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
try (ThreadContext.StoredContext ctx = threadContext.newStoredContext()) {
if (licenseState.securityEnabled()) {
if (licenseState.authenticationAndAuthorizationEnabled()) {
String profile = channel.getProfileName();
ServerTransportFilter filter = profileFilters.get(profile);

View File

@ -167,7 +167,7 @@ public class IPFilter {
}
public boolean accept(String profile, InetAddress peerAddress) {
if (licenseState.securityEnabled() == false) {
if (licenseState.ipFilteringEnabled() == false) {
return true;
}

View File

@ -198,7 +198,7 @@ public class LicensingTests extends ShieldIntegTestCase {
assertThat(httpClient().path("/").execute().getStatusCode(), is(200));
// generate a new license with a mode that enables auth
OperationMode mode = randomFrom(OperationMode.GOLD, OperationMode.TRIAL, OperationMode.PLATINUM);
OperationMode mode = randomFrom(OperationMode.GOLD, OperationMode.TRIAL, OperationMode.PLATINUM, OperationMode.STANDARD);
enableLicensing(mode);
assertThat(httpClient().path("/").execute().getStatusCode(), is(401));
}
@ -217,7 +217,7 @@ public class LicensingTests extends ShieldIntegTestCase {
}
// enable a license that enables security
OperationMode mode = randomFrom(OperationMode.GOLD, OperationMode.PLATINUM, OperationMode.TRIAL);
OperationMode mode = randomFrom(OperationMode.GOLD, OperationMode.TRIAL, OperationMode.PLATINUM, OperationMode.STANDARD);
enableLicensing(mode);
try (TransportClient client = TransportClient.builder().settings(builder).addPlugin(XPackPlugin.class).build()) {

View File

@ -6,8 +6,10 @@
package org.elasticsearch.shield;
import org.elasticsearch.license.core.License;
import org.elasticsearch.license.core.License.OperationMode;
import org.elasticsearch.license.plugin.core.LicenseState;
import org.elasticsearch.license.plugin.core.Licensee;
import org.elasticsearch.shield.SecurityLicenseState.EnabledRealmType;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.is;
@ -19,10 +21,12 @@ public class ShieldLicenseStateTests extends ESTestCase {
public void testDefaults() {
SecurityLicenseState licenseState = new SecurityLicenseState();
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(true));
assertThat(licenseState.auditingEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.ALL));
}
public void testBasic() {
@ -30,20 +34,49 @@ public class ShieldLicenseStateTests extends ESTestCase {
licenseState.updateStatus(new Licensee.Status(License.OperationMode.BASIC,
randomBoolean() ? LicenseState.ENABLED : LicenseState.GRACE_PERIOD));
assertThat(licenseState.securityEnabled(), is(false));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(false));
assertThat(licenseState.ipFilteringEnabled(), is(false));
assertThat(licenseState.auditingEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.NONE));
}
public void testBasicExpired() {
SecurityLicenseState licenseState = new SecurityLicenseState();
licenseState.updateStatus(new Licensee.Status(License.OperationMode.BASIC, LicenseState.DISABLED));
assertThat(licenseState.securityEnabled(), is(false));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(false));
assertThat(licenseState.ipFilteringEnabled(), is(false));
assertThat(licenseState.auditingEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.NONE));
}
public void testStandard() {
SecurityLicenseState licenseState = new SecurityLicenseState();
licenseState.updateStatus(new Licensee.Status(OperationMode.STANDARD,
randomBoolean() ? LicenseState.ENABLED : LicenseState.GRACE_PERIOD));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(false));
assertThat(licenseState.auditingEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.NATIVE));
}
public void testStandardExpired() {
SecurityLicenseState licenseState = new SecurityLicenseState();
licenseState.updateStatus(new Licensee.Status(OperationMode.STANDARD, LicenseState.DISABLED));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(false));
assertThat(licenseState.auditingEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.NATIVE));
}
public void testGold() {
@ -51,20 +84,24 @@ public class ShieldLicenseStateTests extends ESTestCase {
licenseState.updateStatus(new Licensee.Status(License.OperationMode.GOLD,
randomBoolean() ? LicenseState.ENABLED : LicenseState.GRACE_PERIOD));
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(true));
assertThat(licenseState.auditingEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.DEFAULT));
}
public void testGoldExpired() {
SecurityLicenseState licenseState = new SecurityLicenseState();
licenseState.updateStatus(new Licensee.Status(License.OperationMode.GOLD, LicenseState.DISABLED));
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(true));
assertThat(licenseState.auditingEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.DEFAULT));
}
public void testPlatinum() {
@ -72,19 +109,23 @@ public class ShieldLicenseStateTests extends ESTestCase {
licenseState.updateStatus(new Licensee.Status(License.OperationMode.PLATINUM,
randomBoolean() ? LicenseState.ENABLED : LicenseState.GRACE_PERIOD));
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(true));
assertThat(licenseState.auditingEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.ALL));
}
public void testPlatinumExpired() {
SecurityLicenseState licenseState = new SecurityLicenseState();
licenseState.updateStatus(new Licensee.Status(License.OperationMode.PLATINUM, LicenseState.DISABLED));
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.authenticationAndAuthorizationEnabled(), is(true));
assertThat(licenseState.ipFilteringEnabled(), is(true));
assertThat(licenseState.auditingEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
assertThat(licenseState.enabledRealmType(), is(EnabledRealmType.ALL));
}
}

View File

@ -58,8 +58,10 @@ public class ShieldLicenseeTests extends AbstractLicenseeTestCase {
verifyNoMoreInteractions(registry, shieldState);
}
public void testAcknowledgementMessagesFromBasicToAnyNotGoldIsNoOp() {
assertEmptyAck(OperationMode.BASIC, randomModeExcept(OperationMode.GOLD), this::buildLicensee);
public void testAcknowledgementMessagesFromBasicToAnyNotGoldOrStandardIsNoOp() {
assertEmptyAck(OperationMode.BASIC,
randomFrom(OperationMode.values(), mode -> mode != OperationMode.GOLD && mode != OperationMode.STANDARD),
this::buildLicensee);
}
public void testAcknowledgementMessagesFromAnyToTrialOrPlatinumIsNoOp() {
@ -76,6 +78,16 @@ public class ShieldLicenseeTests extends AbstractLicenseeTestCase {
assertThat(fromToMessage(from, to), messages.length, equalTo(3));
}
public void testAcknowlegmentMessagesFromAnyToStandardNotesLimits() {
OperationMode from = randomFrom(OperationMode.BASIC, OperationMode.GOLD, OperationMode.PLATINUM, OperationMode.TRIAL);
OperationMode to = OperationMode.STANDARD;
String[] messages = ackLicenseChange(from, to, this::buildLicensee);
// leaving messages up to inspection
assertThat(fromToMessage(from, to), messages.length, equalTo(4));
}
public void testAcknowledgementMessagesFromBasicStandardTrialOrPlatinumToGoldNotesLimits() {
String[] messages = ackLicenseChange(randomModeExcept(OperationMode.GOLD), OperationMode.GOLD, this::buildLicensee);

View File

@ -47,7 +47,7 @@ public class ShieldActionFilterTests extends ESTestCase {
private AuthorizationService authzService;
private CryptoService cryptoService;
private AuditTrail auditTrail;
private SecurityLicenseState shieldLicenseState;
private SecurityLicenseState securityLicenseState;
private ShieldActionFilter filter;
@Before
@ -56,12 +56,12 @@ public class ShieldActionFilterTests extends ESTestCase {
authzService = mock(AuthorizationService.class);
cryptoService = mock(CryptoService.class);
auditTrail = mock(AuditTrail.class);
shieldLicenseState = mock(SecurityLicenseState.class);
when(shieldLicenseState.securityEnabled()).thenReturn(true);
when(shieldLicenseState.statsAndHealthEnabled()).thenReturn(true);
securityLicenseState = mock(SecurityLicenseState.class);
when(securityLicenseState.authenticationAndAuthorizationEnabled()).thenReturn(true);
when(securityLicenseState.statsAndHealthEnabled()).thenReturn(true);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
filter = new ShieldActionFilter(Settings.EMPTY, authcService, authzService, cryptoService, auditTrail, shieldLicenseState,
filter = new ShieldActionFilter(Settings.EMPTY, authcService, authzService, cryptoService, auditTrail, securityLicenseState,
new ShieldActionMapper(), new HashSet<>(), threadPool);
}
@ -128,7 +128,7 @@ public class ShieldActionFilterTests extends ESTestCase {
ActionListener listener = mock(ActionListener.class);
ActionFilterChain chain = mock(ActionFilterChain.class);
Task task = mock(Task.class);
when(shieldLicenseState.securityEnabled()).thenReturn(false);
when(securityLicenseState.authenticationAndAuthorizationEnabled()).thenReturn(false);
filter.apply(task, "_action", request, listener, chain);
verifyZeroInteractions(authcService);
verifyZeroInteractions(authzService);

View File

@ -7,6 +7,7 @@ package org.elasticsearch.shield.audit;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.SecurityLicenseState;
import org.elasticsearch.shield.user.User;
import org.elasticsearch.shield.authc.AuthenticationToken;
import org.elasticsearch.shield.transport.filter.IPFilter;
@ -22,6 +23,8 @@ import java.util.Set;
import static java.util.Collections.unmodifiableSet;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;
/**
*
@ -33,6 +36,8 @@ public class AuditTrailServiceTests extends ESTestCase {
private AuthenticationToken token;
private TransportMessage message;
private RestRequest restRequest;
private SecurityLicenseState securityLicenseState;
private boolean auditingEnabled;
@Before
public void init() throws Exception {
@ -41,7 +46,10 @@ public class AuditTrailServiceTests extends ESTestCase {
auditTrailsBuilder.add(mock(AuditTrail.class));
}
auditTrails = unmodifiableSet(auditTrailsBuilder);
service = new AuditTrailService(Settings.EMPTY, auditTrails);
securityLicenseState = mock(SecurityLicenseState.class);
service = new AuditTrailService(Settings.EMPTY, auditTrails, securityLicenseState);
auditingEnabled = randomBoolean();
when(securityLicenseState.auditingEnabled()).thenReturn(auditingEnabled);
token = mock(AuthenticationToken.class);
message = mock(TransportMessage.class);
restRequest = mock(RestRequest.class);
@ -49,66 +57,111 @@ public class AuditTrailServiceTests extends ESTestCase {
public void testAuthenticationFailed() throws Exception {
service.authenticationFailed(token, "_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(token, "_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(token, "_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAuthenticationFailedNoToken() throws Exception {
service.authenticationFailed("_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAuthenticationFailedRestNoToken() throws Exception {
service.authenticationFailed(restRequest);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(restRequest);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(restRequest);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAuthenticationFailedRest() throws Exception {
service.authenticationFailed(token, restRequest);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(token, restRequest);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed(token, restRequest);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAuthenticationFailedRealm() throws Exception {
service.authenticationFailed("_realm", token, "_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_realm", token, "_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_realm", token, "_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAuthenticationFailedRestRealm() throws Exception {
service.authenticationFailed("_realm", token, restRequest);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_realm", token, restRequest);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).authenticationFailed("_realm", token, restRequest);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAnonymousAccess() throws Exception {
service.anonymousAccessDenied("_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).anonymousAccessDenied("_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).anonymousAccessDenied("_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAccessGranted() throws Exception {
User user = new User("_username", "r1");
service.accessGranted(user, "_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).accessGranted(user, "_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).accessGranted(user, "_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
public void testAccessDenied() throws Exception {
User user = new User("_username", "r1");
service.accessDenied(user, "_action", message);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).accessDenied(user, "_action", message);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).accessDenied(user, "_action", message);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
@ -116,8 +169,13 @@ public class AuditTrailServiceTests extends ESTestCase {
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = randomBoolean() ? ShieldIpFilterRule.ACCEPT_ALL : IPFilter.DEFAULT_PROFILE_ACCEPT_ALL;
service.connectionGranted(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).connectionGranted(inetAddress, "client", rule);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).connectionGranted(inetAddress, "client", rule);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
@ -125,8 +183,13 @@ public class AuditTrailServiceTests extends ESTestCase {
InetAddress inetAddress = InetAddress.getLoopbackAddress();
ShieldIpFilterRule rule = new ShieldIpFilterRule(false, "_all");
service.connectionDenied(inetAddress, "client", rule);
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).connectionDenied(inetAddress, "client", rule);
verify(securityLicenseState).auditingEnabled();
if (auditingEnabled) {
for (AuditTrail auditTrail : auditTrails) {
verify(auditTrail).connectionDenied(inetAddress, "client", rule);
}
} else {
verifyZeroInteractions(auditTrails.toArray((Object[]) new AuditTrail[auditTrails.size()]));
}
}
}

View File

@ -17,6 +17,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.rest.RestController;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.authc.InternalAuthenticationService.AuditableRequest;
import org.elasticsearch.shield.SecurityLicenseState.EnabledRealmType;
import org.elasticsearch.shield.user.AnonymousUser;
import org.elasticsearch.shield.user.SystemUser;
import org.elasticsearch.shield.user.User;
@ -92,7 +93,7 @@ public class InternalAuthenticationServiceTests extends ESTestCase {
when(secondRealm.name()).thenReturn("second");
Settings settings = Settings.builder().put("path.home", createTempDir()).build();
SecurityLicenseState shieldLicenseState = mock(SecurityLicenseState.class);
when(shieldLicenseState.customRealmsEnabled()).thenReturn(true);
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.ALL);
realms = new Realms(Settings.EMPTY, new Environment(settings), Collections.<String, Realm.Factory>emptyMap(), shieldLicenseState,
mock(ReservedRealm.class)) {

View File

@ -8,6 +8,7 @@ package org.elasticsearch.shield.authc;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.env.Environment;
import org.elasticsearch.shield.SecurityLicenseState.EnabledRealmType;
import org.elasticsearch.shield.user.User;
import org.elasticsearch.shield.authc.esnative.ReservedRealm;
import org.elasticsearch.shield.authc.esnative.NativeRealm;
@ -51,7 +52,7 @@ public class RealmsTests extends ESTestCase {
}
shieldLicenseState = mock(SecurityLicenseState.class);
reservedRealm = mock(ReservedRealm.class);
when(shieldLicenseState.customRealmsEnabled()).thenReturn(true);
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.ALL);
}
public void testWithSettings() throws Exception {
@ -159,7 +160,23 @@ public class RealmsTests extends ESTestCase {
i++;
}
when(shieldLicenseState.customRealmsEnabled()).thenReturn(false);
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.DEFAULT);
iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm, is(reservedRealm));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), equalTo(FileRealm.TYPE));
assertThat(realm.name(), equalTo("default_" + FileRealm.TYPE));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), equalTo(NativeRealm.TYPE));
assertThat(realm.name(), equalTo("default_" + NativeRealm.TYPE));
assertThat(iter.hasNext(), is(false));
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.NATIVE);
iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
@ -204,7 +221,7 @@ public class RealmsTests extends ESTestCase {
}
assertThat(types, contains("ldap", "type_0"));
when(shieldLicenseState.customRealmsEnabled()).thenReturn(false);
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.DEFAULT);
iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
realm = iter.next();
@ -216,6 +233,57 @@ public class RealmsTests extends ESTestCase {
i++;
}
assertThat(i, is(1));
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.NATIVE);
iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm, is(reservedRealm));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), equalTo(FileRealm.TYPE));
assertThat(realm.name(), equalTo("default_" + FileRealm.TYPE));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), equalTo(NativeRealm.TYPE));
assertThat(realm.name(), equalTo("default_" + NativeRealm.TYPE));
assertThat(iter.hasNext(), is(false));
}
public void testUnlicensedWithNativeRealms() throws Exception {
factories.put(LdapRealm.TYPE, new DummyRealm.Factory(LdapRealm.TYPE, false));
final String type = randomFrom(FileRealm.TYPE, NativeRealm.TYPE);
Settings.Builder builder = Settings.builder()
.put("path.home", createTempDir())
.put("xpack.security.authc.realms.foo.type", "ldap")
.put("xpack.security.authc.realms.foo.order", "0")
.put("xpack.security.authc.realms.native.type", type)
.put("xpack.security.authc.realms.native.order", "1");
Settings settings = builder.build();
Environment env = new Environment(settings);
Realms realms = new Realms(settings, env, factories, shieldLicenseState, reservedRealm);
realms.start();
Iterator<Realm> iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
Realm realm = iter.next();
assertThat(realm, is(reservedRealm));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), is("ldap"));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), is(type));
assertThat(iter.hasNext(), is(false));
when(shieldLicenseState.enabledRealmType()).thenReturn(EnabledRealmType.NATIVE);
iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm, is(reservedRealm));
assertThat(iter.hasNext(), is(true));
realm = iter.next();
assertThat(realm.type(), is(type));
assertThat(iter.hasNext(), is(false));
}
public void testDisabledRealmsAreNotAdded() throws Exception {

View File

@ -43,7 +43,7 @@ public class ShieldRestFilterTests extends ESTestCase {
channel = mock(RestChannel.class);
chain = mock(RestFilterChain.class);
licenseState = mock(SecurityLicenseState.class);
when(licenseState.securityEnabled()).thenReturn(true);
when(licenseState.authenticationAndAuthorizationEnabled()).thenReturn(true);
ThreadPool threadPool = mock(ThreadPool.class);
when(threadPool.getThreadContext()).thenReturn(new ThreadContext(Settings.EMPTY));
filter = new ShieldRestFilter(authcService, restController, Settings.EMPTY, threadPool, licenseState);
@ -61,7 +61,7 @@ public class ShieldRestFilterTests extends ESTestCase {
public void testProcessBasicLicense() throws Exception {
RestRequest request = mock(RestRequest.class);
when(licenseState.securityEnabled()).thenReturn(false);
when(licenseState.authenticationAndAuthorizationEnabled()).thenReturn(false);
filter.process(request, channel, chain);
verify(chain).continueProcessing(request, channel);
verifyZeroInteractions(channel, authcService);

View File

@ -311,7 +311,7 @@ public class TransportFilterTests extends ESIntegTestCase {
ClientTransportFilter clientTransportFilter) {
super(settings, transport, threadPool, authcService, authzService, actionMapper, clientTransportFilter,
mock(SecurityLicenseState.class));
when(licenseState.securityEnabled()).thenReturn(true);
when(licenseState.authenticationAndAuthorizationEnabled()).thenReturn(true);
}
@Override

View File

@ -53,7 +53,7 @@ public class IPFilterTests extends ESTestCase {
@Before
public void init() {
licenseState = mock(SecurityLicenseState.class);
when(licenseState.securityEnabled()).thenReturn(true);
when(licenseState.ipFilteringEnabled()).thenReturn(true);
auditTrail = mock(AuditTrail.class);
clusterSettings = new ClusterSettings(Settings.EMPTY, new HashSet<>(Arrays.asList(
IPFilter.HTTP_FILTER_ALLOW_SETTING,
@ -183,7 +183,7 @@ public class IPFilterTests extends ESTestCase {
.put("xpack.security.transport.filter.deny", "10.0.0.0/8")
.build();
ipFilter = new IPFilter(settings, auditTrail, clusterSettings, licenseState);
ipFilter.setBoundHttpTransportAddress(httpTransport.boundAddress());
ipFilter.setBoundHttpTransportAddress(httpTransport.boundAddress());
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
assertAddressIsAllowedForProfile(IPFilter.HTTP_PROFILE_NAME, "127.0.0.1");
@ -205,7 +205,7 @@ public class IPFilterTests extends ESTestCase {
}
ipFilter = new IPFilter(settings, auditTrail, clusterSettings, licenseState);
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
ipFilter.setBoundHttpTransportAddress(httpTransport.boundAddress());
ipFilter.setBoundHttpTransportAddress(httpTransport.boundAddress());
for (String addressString : addressStrings) {
assertAddressIsAllowedForProfile(IPFilter.HTTP_PROFILE_NAME, addressString);
@ -217,7 +217,7 @@ public class IPFilterTests extends ESTestCase {
Settings settings = Settings.builder()
.put("xpack.security.transport.filter.deny", "_all")
.build();
when(licenseState.securityEnabled()).thenReturn(false);
when(licenseState.ipFilteringEnabled()).thenReturn(false);
ipFilter = new IPFilter(settings, auditTrail, clusterSettings, licenseState);
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
@ -228,7 +228,7 @@ public class IPFilterTests extends ESTestCase {
verifyZeroInteractions(auditTrail);
// for sanity enable license and check that it is denied
when(licenseState.securityEnabled()).thenReturn(true);
when(licenseState.ipFilteringEnabled()).thenReturn(true);
ipFilter = new IPFilter(settings, auditTrail, clusterSettings, licenseState);
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());

View File

@ -69,7 +69,7 @@ public class IPFilterNettyUpstreamHandlerTests extends ESTestCase {
IPFilter.TRANSPORT_FILTER_DENY_SETTING,
TransportSettings.TRANSPORT_PROFILES_SETTING)));
SecurityLicenseState licenseState = mock(SecurityLicenseState.class);
when(licenseState.securityEnabled()).thenReturn(true);
when(licenseState.ipFilteringEnabled()).thenReturn(true);
IPFilter ipFilter = new IPFilter(settings, AuditTrail.NOOP, clusterSettings, licenseState);
ipFilter.setBoundTransportAddress(transport.boundAddress(), transport.profileBoundAddresses());
if (isHttpEnabled) {