shield: only enable custom realms with a platinum or trial license

In elastic/elasticsearch#788, the enabling and disabling of features was added for shield, but custom realms were not
being enabled/disabled based on license type. This commit adds that functionality.

Relates to elastic/elasticsearch#689

Original commit: elastic/x-pack-elasticsearch@625c3ef18a
This commit is contained in:
jaymode 2015-10-14 12:36:38 -04:00
parent b927fd08bc
commit f19e68ecb8
7 changed files with 138 additions and 17 deletions

View File

@ -21,7 +21,7 @@ import java.util.Map.Entry;
*/
public class AuthenticationModule extends AbstractShieldModule.Node {
private static final List<String> INTERNAL_REALM_TYPES = Arrays.asList(ESUsersRealm.TYPE, ActiveDirectoryRealm.TYPE, LdapRealm.TYPE, PkiRealm.TYPE);
static final List<String> INTERNAL_REALM_TYPES = Arrays.asList(ESUsersRealm.TYPE, ActiveDirectoryRealm.TYPE, LdapRealm.TYPE, PkiRealm.TYPE);
private final Map<String, Class<? extends Realm.Factory<? extends Realm<? extends AuthenticationToken>>>> customRealms = new HashMap<>();

View File

@ -12,9 +12,9 @@ import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.env.Environment;
import org.elasticsearch.shield.ShieldSettingsFilter;
import org.elasticsearch.shield.authc.esusers.ESUsersRealm;
import org.elasticsearch.shield.license.ShieldLicenseState;
import java.util.*;
import java.util.concurrent.CopyOnWriteArrayList;
/**
* Serves as a realms registry (also responsible for ordering the realms appropriately)
@ -24,20 +24,40 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
private final Environment env;
private final Map<String, Realm.Factory> factories;
private final ShieldSettingsFilter settingsFilter;
private final ShieldLicenseState shieldLicenseState;
private List<Realm> realms = Collections.emptyList();
protected List<Realm> realms = Collections.emptyList();
// a list of realms that are "internal" in that they are provided by shield and not a third party
protected List<Realm> internalRealmsOnly = Collections.emptyList();
@Inject
public Realms(Settings settings, Environment env, Map<String, Realm.Factory> factories, ShieldSettingsFilter settingsFilter) {
public Realms(Settings settings, Environment env, Map<String, Realm.Factory> factories,
ShieldSettingsFilter settingsFilter, ShieldLicenseState shieldLicenseState) {
super(settings);
this.env = env;
this.factories = factories;
this.settingsFilter = settingsFilter;
this.shieldLicenseState = shieldLicenseState;
}
@Override
protected void doStart() throws ElasticsearchException {
realms = new CopyOnWriteArrayList<>(initRealms());
this.realms = initRealms();
// pre-computing a list of internal only realms allows us to have much cheaper iteration than a custom iterator
// and is also simpler in terms of logic. These lists are small, so the duplication should not be a real issue here
List<Realm> internalRealms = new ArrayList<>();
for (Realm realm : realms) {
if (AuthenticationModule.INTERNAL_REALM_TYPES.contains(realm.type())) {
internalRealms.add(realm);
}
}
if (internalRealms.isEmpty()) {
// lets create a default one so they can do something
internalRealms.add(factories.get(ESUsersRealm.TYPE).createDefault("default_" + ESUsersRealm.TYPE));
}
this.internalRealmsOnly = Collections.unmodifiableList(internalRealms);
}
@Override
@ -48,7 +68,10 @@ public class Realms extends AbstractLifecycleComponent<Realms> implements Iterab
@Override
public Iterator<Realm> iterator() {
return realms.iterator();
if (shieldLicenseState.customRealmsEnabled()) {
return realms.iterator();
}
return internalRealmsOnly.iterator();
}
public Realm realm(String name) {

View File

@ -44,6 +44,14 @@ public class ShieldLicenseState {
return status.getMode() == OperationMode.PLATINUM || status.getMode() == OperationMode.TRIAL;
}
/**
* @return true if the license enables the use of custom authentication realms
*/
public boolean customRealmsEnabled() {
Status status = this.status;
return status.getMode() == OperationMode.PLATINUM || status.getMode() == OperationMode.TRIAL;
}
void updateStatus(Status status) {
this.status = status;
}

View File

@ -63,7 +63,8 @@ public class ShieldLicensee extends AbstractLicenseeComponent<ShieldLicensee> im
case BASIC:
case PLATINUM:
return new String[] {
"Field and document level access control will be disabled"
"Field and document level access control will be disabled",
"Custom realms will be ignored"
};
}
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.shield.audit.AuditTrail;
import org.elasticsearch.shield.authc.support.SecuredString;
import org.elasticsearch.shield.authc.support.UsernamePasswordToken;
import org.elasticsearch.shield.crypto.CryptoService;
import org.elasticsearch.shield.license.ShieldLicenseState;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.rest.FakeRestRequest;
import org.elasticsearch.transport.TransportMessage;
@ -28,7 +29,6 @@ import org.junit.rules.ExpectedException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.elasticsearch.shield.support.Exceptions.authenticationError;
import static org.elasticsearch.test.ShieldTestsUtils.assertAuthenticationException;
@ -61,15 +61,20 @@ public class InternalAuthenticationServiceTests extends ESTestCase {
message = new InternalMessage();
restRequest = new FakeRestRequest();
firstRealm = mock(Realm.class);
when(firstRealm.type()).thenReturn("first");
when(firstRealm.type()).thenReturn("esusers");
secondRealm = mock(Realm.class);
when(secondRealm.type()).thenReturn("second");
Settings settings = Settings.builder().put("path.home", createTempDir()).build();
realms = new Realms(Settings.EMPTY, new Environment(settings), Collections.<String, Realm.Factory>emptyMap(), mock(ShieldSettingsFilter.class)) {
ShieldLicenseState shieldLicenseState = mock(ShieldLicenseState.class);
when(shieldLicenseState.customRealmsEnabled()).thenReturn(true);
realms = new Realms(Settings.EMPTY, new Environment(settings), Collections.<String, Realm.Factory>emptyMap(), mock(ShieldSettingsFilter.class), shieldLicenseState) {
@Override
protected List<Realm> initRealms() {
return Arrays.asList(firstRealm, secondRealm);
protected void doStart() {
this.realms = Arrays.asList(firstRealm, secondRealm);
this.internalRealmsOnly = Collections.singletonList(firstRealm);
}
};
realms.start();
cryptoService = mock(CryptoService.class);
@ -127,7 +132,7 @@ public class InternalAuthenticationServiceTests extends ESTestCase {
User result = service.authenticate("_action", message, null);
assertThat(result, notNullValue());
assertThat(result, is(user));
verify(auditTrail).authenticationFailed("first", token, "_action", message);
verify(auditTrail).authenticationFailed("esusers", token, "_action", message);
assertThat(message.getContext().get(InternalAuthenticationService.USER_KEY), notNullValue());
assertThat(message.getContext().get(InternalAuthenticationService.USER_KEY), sameInstance((Object) user));
assertThat(message.getHeader(InternalAuthenticationService.USER_KEY), equalTo((Object) "_encoded_user"));

View File

@ -11,6 +11,8 @@ import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.shield.ShieldSettingsFilter;
import org.elasticsearch.shield.User;
import org.elasticsearch.shield.authc.esusers.ESUsersRealm;
import org.elasticsearch.shield.authc.ldap.LdapRealm;
import org.elasticsearch.shield.license.ShieldLicenseState;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.TransportMessage;
import org.junit.Before;
@ -20,6 +22,7 @@ import java.util.*;
import static org.hamcrest.Matchers.*;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
*
@ -28,6 +31,7 @@ public class RealmsTests extends ESTestCase {
private Map<String, Realm.Factory> factories;
private ShieldSettingsFilter settingsFilter;
private ShieldLicenseState shieldLicenseState;
@Before
public void init() throws Exception {
@ -38,6 +42,8 @@ public class RealmsTests extends ESTestCase {
factories.put("type_" + i, factory);
}
settingsFilter = mock(ShieldSettingsFilter.class);
shieldLicenseState = mock(ShieldLicenseState.class);
when(shieldLicenseState.customRealmsEnabled()).thenReturn(true);
}
@Test
@ -57,7 +63,7 @@ public class RealmsTests extends ESTestCase {
}
Settings settings = builder.build();
Environment env = new Environment(settings);
Realms realms = new Realms(settings, env, factories, settingsFilter);
Realms realms = new Realms(settings, env, factories, settingsFilter, shieldLicenseState);
realms.start();
int i = 0;
for (Realm realm : realms) {
@ -79,12 +85,13 @@ public class RealmsTests extends ESTestCase {
.put("path.home", createTempDir())
.build();
Environment env = new Environment(settings);
new Realms(settings, env, factories, settingsFilter).start();
new Realms(settings, env, factories, settingsFilter, shieldLicenseState).start();
}
@Test
public void testWithEmptySettings() throws Exception {
Realms realms = new Realms(Settings.EMPTY, new Environment(Settings.builder().put("path.home", createTempDir()).build()), factories, settingsFilter);
Realms realms = new Realms(Settings.EMPTY, new Environment(Settings.builder().put("path.home", createTempDir()).build()),
factories, settingsFilter, shieldLicenseState);
realms.start();
Iterator<Realm> iter = realms.iterator();
assertThat(iter.hasNext(), is(true));
@ -95,6 +102,76 @@ public class RealmsTests extends ESTestCase {
assertThat(iter.hasNext(), is(false));
}
@Test
public void testUnlicensedWithOnlyCustomRealms() throws Exception {
Settings.Builder builder = Settings.builder()
.put("path.home", createTempDir());
List<Integer> orders = new ArrayList<>(factories.size() - 1);
for (int i = 0; i < factories.size() - 1; i++) {
orders.add(i);
}
Collections.shuffle(orders, getRandom());
Map<Integer, Integer> orderToIndex = new HashMap<>();
for (int i = 0; i < factories.size() - 1; i++) {
builder.put("shield.authc.realms.realm_" + i + ".type", "type_" + i);
builder.put("shield.authc.realms.realm_" + i + ".order", orders.get(i));
orderToIndex.put(orders.get(i), i);
}
Settings settings = builder.build();
Environment env = new Environment(settings);
Realms realms = new Realms(settings, env, factories, settingsFilter, shieldLicenseState);
realms.start();
int i = 0;
// this is the iterator when licensed
for (Realm realm : realms) {
assertThat(realm.order(), equalTo(i));
int index = orderToIndex.get(i);
assertThat(realm.type(), equalTo("type_" + index));
assertThat(realm.name(), equalTo("realm_" + index));
i++;
}
i = 0;
when(shieldLicenseState.customRealmsEnabled()).thenReturn(false);
for (Realm realm : realms) {
assertThat(realm.type, is(ESUsersRealm.TYPE));
i++;
}
assertThat(i, is(1));
}
@Test
public void testUnlicensedWithInternalRealms() throws Exception {
factories.put(LdapRealm.TYPE, new DummyRealm.Factory(LdapRealm.TYPE, false));
assertThat(factories.get("type_1"), notNullValue());
Settings.Builder builder = Settings.builder()
.put("path.home", createTempDir())
.put("shield.authc.realms.foo.type", "ldap")
.put("shield.authc.realms.foo.order", "0")
.put("shield.authc.realms.custom.type", "type_1")
.put("shield.authc.realms.custom.order", "1");
Settings settings = builder.build();
Environment env = new Environment(settings);
Realms realms = new Realms(settings, env, factories, settingsFilter, shieldLicenseState);
realms.start();
int i = 0;
// this is the iterator when licensed
List<String> types = new ArrayList<>();
for (Realm realm : realms) {
i++;
types.add(realm.type());
}
assertThat(types, contains("ldap", "type_1"));
i = 0;
when(shieldLicenseState.customRealmsEnabled()).thenReturn(false);
for (Realm realm : realms) {
assertThat(realm.type, is("ldap"));
i++;
}
assertThat(i, is(1));
}
@Test
public void testDisabledRealmsAreNotAdded() throws Exception {
Settings.Builder builder = Settings.builder()
@ -117,7 +194,7 @@ public class RealmsTests extends ESTestCase {
}
Settings settings = builder.build();
Environment env = new Environment(settings);
Realms realms = new Realms(settings, env, factories, mock(ShieldSettingsFilter.class));
Realms realms = new Realms(settings, env, factories, mock(ShieldSettingsFilter.class), shieldLicenseState);
realms.start();
Iterator<Realm> iterator = realms.iterator();

View File

@ -22,6 +22,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
}
public void testBasic() {
@ -31,6 +32,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
}
public void testBasicExpired() {
@ -40,6 +42,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(false));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
}
public void testGold() {
@ -49,6 +52,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
}
public void testGoldExpired() {
@ -58,6 +62,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(false));
assertThat(licenseState.customRealmsEnabled(), is(false));
}
public void testPlatinum() {
@ -67,6 +72,7 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(true));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
}
public void testPlatinumExpired() {
@ -76,5 +82,6 @@ public class ShieldLicenseStateTests extends ESTestCase {
assertThat(licenseState.securityEnabled(), is(true));
assertThat(licenseState.statsAndHealthEnabled(), is(false));
assertThat(licenseState.documentAndFieldLevelSecurityEnabled(), is(true));
assertThat(licenseState.customRealmsEnabled(), is(true));
}
}