Security: reorder realms based on last success (#36878)

This commit reorders the realm list for iteration based on the last
successful authentication for the given principal. This is an
optimization to prevent unnecessary iteration over realms if we can
make a smart guess on which realm to try first.
This commit is contained in:
Jay Modi 2019-01-10 09:06:16 -07:00 committed by GitHub
parent fcf7df3eda
commit 71633775fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 257 additions and 3 deletions

View File

@ -463,6 +463,7 @@ public class Security extends Plugin implements ActionPlugin, IngestPlugin, Netw
authcService.set(new AuthenticationService(settings, realms, auditTrailService, failureHandler, threadPool, authcService.set(new AuthenticationService(settings, realms, auditTrailService, failureHandler, threadPool,
anonymousUser, tokenService)); anonymousUser, tokenService));
components.add(authcService.get()); components.add(authcService.get());
securityIndex.get().addIndexStateListener(authcService.get()::onSecurityIndexStateChange);
final NativePrivilegeStore privilegeStore = new NativePrivilegeStore(settings, client, securityIndex.get()); final NativePrivilegeStore privilegeStore = new NativePrivilegeStore(settings, client, securityIndex.get());
components.add(privilegeStore); components.add(privilegeStore);

View File

@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.security.action.realm.ClearRealmCacheAction;
import org.elasticsearch.xpack.core.security.action.realm.ClearRealmCacheRequest; import org.elasticsearch.xpack.core.security.action.realm.ClearRealmCacheRequest;
import org.elasticsearch.xpack.core.security.action.realm.ClearRealmCacheResponse; import org.elasticsearch.xpack.core.security.action.realm.ClearRealmCacheResponse;
import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.Realm;
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.authc.support.CachingRealm; import org.elasticsearch.xpack.security.authc.support.CachingRealm;
@ -26,14 +27,16 @@ public class TransportClearRealmCacheAction extends TransportNodesAction<ClearRe
ClearRealmCacheRequest.Node, ClearRealmCacheResponse.Node> { ClearRealmCacheRequest.Node, ClearRealmCacheResponse.Node> {
private final Realms realms; private final Realms realms;
private final AuthenticationService authenticationService;
@Inject @Inject
public TransportClearRealmCacheAction(ThreadPool threadPool, ClusterService clusterService, TransportService transportService, public TransportClearRealmCacheAction(ThreadPool threadPool, ClusterService clusterService, TransportService transportService,
ActionFilters actionFilters, Realms realms) { ActionFilters actionFilters, Realms realms, AuthenticationService authenticationService) {
super(ClearRealmCacheAction.NAME, threadPool, clusterService, transportService, actionFilters, super(ClearRealmCacheAction.NAME, threadPool, clusterService, transportService, actionFilters,
ClearRealmCacheRequest::new, ClearRealmCacheRequest.Node::new, ThreadPool.Names.MANAGEMENT, ClearRealmCacheRequest::new, ClearRealmCacheRequest.Node::new, ThreadPool.Names.MANAGEMENT,
ClearRealmCacheResponse.Node.class); ClearRealmCacheResponse.Node.class);
this.realms = realms; this.realms = realms;
this.authenticationService = authenticationService;
} }
@Override @Override
@ -68,9 +71,23 @@ public class TransportClearRealmCacheAction extends TransportNodesAction<ClearRe
} }
clearCache(realm, nodeRequest.getUsernames()); clearCache(realm, nodeRequest.getUsernames());
} }
clearAuthenticationServiceCache(nodeRequest.getUsernames());
return new ClearRealmCacheResponse.Node(clusterService.localNode()); return new ClearRealmCacheResponse.Node(clusterService.localNode());
} }
private void clearAuthenticationServiceCache(String[] usernames) {
// this is heavy handed since we could also take realm into account but that would add
// complexity since we would need to iterate over the cache under a lock to remove all
// entries that referenced the specific realm
if (usernames != null && usernames.length != 0) {
for (String username : usernames) {
authenticationService.expire(username);
}
} else {
authenticationService.expireAll();
}
}
private void clearCache(Realm realm, String[] usernames) { private void clearCache(Realm realm, String[] usernames) {
if (!(realm instanceof CachingRealm)) { if (!(realm instanceof CachingRealm)) {
return; return;

View File

@ -13,9 +13,13 @@ import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ContextPreservingActionListener; import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.cache.Cache;
import org.elasticsearch.common.cache.CacheBuilder;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.node.Node; import org.elasticsearch.node.Node;
import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestRequest;
@ -38,14 +42,21 @@ import org.elasticsearch.xpack.security.audit.AuditTrail;
import org.elasticsearch.xpack.security.audit.AuditTrailService; import org.elasticsearch.xpack.security.audit.AuditTrailService;
import org.elasticsearch.xpack.security.audit.AuditUtil; import org.elasticsearch.xpack.security.audit.AuditUtil;
import org.elasticsearch.xpack.security.authc.support.RealmUserLookup; import org.elasticsearch.xpack.security.authc.support.RealmUserLookup;
import org.elasticsearch.xpack.security.support.SecurityIndexManager;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
import static org.elasticsearch.xpack.security.support.SecurityIndexManager.isIndexDeleted;
import static org.elasticsearch.xpack.security.support.SecurityIndexManager.isMoveFromRedToNonRed;
/** /**
* An authentication service that delegates the authentication process to its configured {@link Realm realms}. * An authentication service that delegates the authentication process to its configured {@link Realm realms}.
* This service also supports request level caching of authenticated users (i.e. once a user authenticated * This service also supports request level caching of authenticated users (i.e. once a user authenticated
@ -53,6 +64,12 @@ import java.util.function.Consumer;
*/ */
public class AuthenticationService { public class AuthenticationService {
static final Setting<Boolean> SUCCESS_AUTH_CACHE_ENABLED =
Setting.boolSetting("xpack.security.authc.success_cache.enabled", true, Property.NodeScope);
private static final Setting<Integer> SUCCESS_AUTH_CACHE_MAX_SIZE =
Setting.intSetting("xpack.security.authc.success_cache.size", 10000, Property.NodeScope);
private static final Setting<TimeValue> SUCCESS_AUTH_CACHE_EXPIRE_AFTER_ACCESS =
Setting.timeSetting("xpack.security.authc.success_cache.expire_after_access", TimeValue.timeValueHours(1L), Property.NodeScope);
private static final Logger logger = LogManager.getLogger(AuthenticationService.class); private static final Logger logger = LogManager.getLogger(AuthenticationService.class);
private final Realms realms; private final Realms realms;
@ -62,6 +79,8 @@ public class AuthenticationService {
private final String nodeName; private final String nodeName;
private final AnonymousUser anonymousUser; private final AnonymousUser anonymousUser;
private final TokenService tokenService; private final TokenService tokenService;
private final Cache<String, Realm> lastSuccessfulAuthCache;
private final AtomicLong numInvalidation = new AtomicLong();
private final boolean runAsEnabled; private final boolean runAsEnabled;
private final boolean isAnonymousUserEnabled; private final boolean isAnonymousUserEnabled;
@ -77,6 +96,14 @@ public class AuthenticationService {
this.runAsEnabled = AuthenticationServiceField.RUN_AS_ENABLED.get(settings); this.runAsEnabled = AuthenticationServiceField.RUN_AS_ENABLED.get(settings);
this.isAnonymousUserEnabled = AnonymousUser.isAnonymousEnabled(settings); this.isAnonymousUserEnabled = AnonymousUser.isAnonymousEnabled(settings);
this.tokenService = tokenService; this.tokenService = tokenService;
if (SUCCESS_AUTH_CACHE_ENABLED.get(settings)) {
this.lastSuccessfulAuthCache = CacheBuilder.<String, Realm>builder()
.setMaximumWeight(Integer.toUnsignedLong(SUCCESS_AUTH_CACHE_MAX_SIZE.get(settings)))
.setExpireAfterAccess(SUCCESS_AUTH_CACHE_EXPIRE_AFTER_ACCESS.get(settings))
.build();
} else {
this.lastSuccessfulAuthCache = null;
}
} }
/** /**
@ -120,6 +147,28 @@ public class AuthenticationService {
new Authenticator(action, message, null, listener).authenticateToken(token); new Authenticator(action, message, null, listener).authenticateToken(token);
} }
public void expire(String principal) {
if (lastSuccessfulAuthCache != null) {
numInvalidation.incrementAndGet();
lastSuccessfulAuthCache.invalidate(principal);
}
}
public void expireAll() {
if (lastSuccessfulAuthCache != null) {
numInvalidation.incrementAndGet();
lastSuccessfulAuthCache.invalidateAll();
}
}
public void onSecurityIndexStateChange(SecurityIndexManager.State previousState, SecurityIndexManager.State currentState) {
if (lastSuccessfulAuthCache != null) {
if (isMoveFromRedToNonRed(previousState, currentState) || isIndexDeleted(previousState, currentState)) {
expireAll();
}
}
}
// pkg private method for testing // pkg private method for testing
Authenticator createAuthenticator(RestRequest request, ActionListener<Authentication> listener) { Authenticator createAuthenticator(RestRequest request, ActionListener<Authentication> listener) {
return new Authenticator(request, listener); return new Authenticator(request, listener);
@ -130,6 +179,11 @@ public class AuthenticationService {
return new Authenticator(action, message, fallbackUser, listener); return new Authenticator(action, message, fallbackUser, listener);
} }
// pkg private method for testing
long getNumInvalidation() {
return numInvalidation.get();
}
/** /**
* This class is responsible for taking a request and executing the authentication. The authentication is executed in an asynchronous * This class is responsible for taking a request and executing the authentication. The authentication is executed in an asynchronous
* fashion in order to avoid blocking calls on a network thread. This class also performs the auditing necessary around authentication * fashion in order to avoid blocking calls on a network thread. This class also performs the auditing necessary around authentication
@ -263,7 +317,8 @@ public class AuthenticationService {
handleNullToken(); handleNullToken();
} else { } else {
authenticationToken = token; authenticationToken = token;
final List<Realm> realmsList = realms.asList(); final List<Realm> realmsList = getRealmList(authenticationToken.principal());
final long startInvalidation = numInvalidation.get();
final Map<Realm, Tuple<String, Exception>> messages = new LinkedHashMap<>(); final Map<Realm, Tuple<String, Exception>> messages = new LinkedHashMap<>();
final BiConsumer<Realm, ActionListener<User>> realmAuthenticatingConsumer = (realm, userListener) -> { final BiConsumer<Realm, ActionListener<User>> realmAuthenticatingConsumer = (realm, userListener) -> {
if (realm.supports(authenticationToken)) { if (realm.supports(authenticationToken)) {
@ -273,6 +328,9 @@ public class AuthenticationService {
// user was authenticated, populate the authenticated by information // user was authenticated, populate the authenticated by information
authenticatedBy = new RealmRef(realm.name(), realm.type(), nodeName); authenticatedBy = new RealmRef(realm.name(), realm.type(), nodeName);
authenticationResult = result; authenticationResult = result;
if (lastSuccessfulAuthCache != null && startInvalidation == numInvalidation.get()) {
lastSuccessfulAuthCache.put(authenticationToken.principal(), realm);
}
userListener.onResponse(result.getUser()); userListener.onResponse(result.getUser());
} else { } else {
// the user was not authenticated, call this so we can audit the correct event // the user was not authenticated, call this so we can audit the correct event
@ -313,6 +371,27 @@ public class AuthenticationService {
} }
} }
private List<Realm> getRealmList(String principal) {
final List<Realm> defaultOrderedRealms = realms.asList();
if (lastSuccessfulAuthCache != null) {
final Realm lastSuccess = lastSuccessfulAuthCache.get(principal);
if (lastSuccess != null) {
final int index = defaultOrderedRealms.indexOf(lastSuccess);
if (index > 0) {
final List<Realm> smartOrder = new ArrayList<>(defaultOrderedRealms.size());
smartOrder.add(lastSuccess);
for (int i = 1; i < defaultOrderedRealms.size(); i++) {
if (i != index) {
smartOrder.add(defaultOrderedRealms.get(i));
}
}
return Collections.unmodifiableList(smartOrder);
}
}
}
return defaultOrderedRealms;
}
/** /**
* Handles failed extraction of an authentication token. This can happen in a few different scenarios: * Handles failed extraction of an authentication token. This can happen in a few different scenarios:
* *
@ -391,7 +470,8 @@ public class AuthenticationService {
* names of users that exist using a timing attack * names of users that exist using a timing attack
*/ */
private void lookupRunAsUser(final User user, String runAsUsername, Consumer<User> userConsumer) { private void lookupRunAsUser(final User user, String runAsUsername, Consumer<User> userConsumer) {
final RealmUserLookup lookup = new RealmUserLookup(realms.asList(), threadContext); final RealmUserLookup lookup = new RealmUserLookup(getRealmList(runAsUsername), threadContext);
final long startInvalidationNum = numInvalidation.get();
lookup.lookup(runAsUsername, ActionListener.wrap(tuple -> { lookup.lookup(runAsUsername, ActionListener.wrap(tuple -> {
if (tuple == null) { if (tuple == null) {
// the user does not exist, but we still create a User object, which will later be rejected by authz // the user does not exist, but we still create a User object, which will later be rejected by authz
@ -400,6 +480,11 @@ public class AuthenticationService {
User foundUser = Objects.requireNonNull(tuple.v1()); User foundUser = Objects.requireNonNull(tuple.v1());
Realm realm = Objects.requireNonNull(tuple.v2()); Realm realm = Objects.requireNonNull(tuple.v2());
lookedupBy = new RealmRef(realm.name(), realm.type(), nodeName); lookedupBy = new RealmRef(realm.name(), realm.type(), nodeName);
if (lastSuccessfulAuthCache != null && startInvalidationNum == numInvalidation.get()) {
// only cache this as last success if it doesn't exist since this really isn't an auth attempt but
// this might provide a valid hint
lastSuccessfulAuthCache.computeIfAbsent(runAsUsername, s -> realm);
}
userConsumer.accept(new User(foundUser, user)); userConsumer.accept(new User(foundUser, user));
} }
}, exception -> listener.onFailure(request.exceptionProcessingRequest(exception, authenticationToken)))); }, exception -> listener.onFailure(request.exceptionProcessingRequest(exception, authenticationToken))));
@ -602,5 +687,8 @@ public class AuthenticationService {
public static void addSettings(List<Setting<?>> settings) { public static void addSettings(List<Setting<?>> settings) {
settings.add(AuthenticationServiceField.RUN_AS_ENABLED); settings.add(AuthenticationServiceField.RUN_AS_ENABLED);
settings.add(SUCCESS_AUTH_CACHE_ENABLED);
settings.add(SUCCESS_AUTH_CACHE_MAX_SIZE);
settings.add(SUCCESS_AUTH_CACHE_EXPIRE_AFTER_ACCESS);
} }
} }

View File

@ -22,6 +22,7 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.update.UpdateAction; import org.elasticsearch.action.update.UpdateAction;
import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.health.ClusterHealthStatus;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.SuppressForbidden; import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
@ -103,6 +104,7 @@ import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.reset; import static org.mockito.Mockito.reset;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
@ -133,6 +135,7 @@ public class AuthenticationServiceTests extends ESTestCase {
@SuppressForbidden(reason = "Allow accessing localhost") @SuppressForbidden(reason = "Allow accessing localhost")
public void init() throws Exception { public void init() throws Exception {
token = mock(AuthenticationToken.class); token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
message = new InternalMessage(); message = new InternalMessage();
remoteAddress = new InetSocketAddress(InetAddress.getLocalHost(), 100); remoteAddress = new InetSocketAddress(InetAddress.getLocalHost(), 100);
message.remoteAddress(new TransportAddress(remoteAddress)); message.remoteAddress(new TransportAddress(remoteAddress));
@ -258,6 +261,134 @@ public class AuthenticationServiceTests extends ESTestCase {
verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", message); verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", message);
} }
public void testAuthenticateSmartRealmOrdering() {
User user = new User("_username", "r1");
when(firstRealm.supports(token)).thenReturn(true);
mockAuthenticate(firstRealm, token, null);
when(secondRealm.supports(token)).thenReturn(true);
mockAuthenticate(secondRealm, token, user);
when(secondRealm.token(threadContext)).thenReturn(token);
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate("_action", message, (User)null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.get());
completed.set(false);
service.authenticate("_action", message, (User)null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
verify(auditTrail).authenticationFailed(reqId, firstRealm.name(), token, "_action", message);
verify(auditTrail, times(2)).authenticationSuccess(reqId, secondRealm.name(), user, "_action", message);
verify(firstRealm, times(2)).name(); // used above one time
verify(secondRealm, times(3)).name(); // used above one time
verify(secondRealm, times(2)).type(); // used to create realm ref
verify(firstRealm, times(2)).token(threadContext);
verify(secondRealm, times(2)).token(threadContext);
verify(firstRealm).supports(token);
verify(secondRealm, times(2)).supports(token);
verify(firstRealm).authenticate(eq(token), any(ActionListener.class));
verify(secondRealm, times(2)).authenticate(eq(token), any(ActionListener.class));
verifyNoMoreInteractions(auditTrail, firstRealm, secondRealm);
}
public void testCacheClearOnSecurityIndexChange() {
long expectedInvalidation = 0L;
assertEquals(expectedInvalidation, service.getNumInvalidation());
// existing to no longer present
SecurityIndexManager.State previousState = dummyState(randomFrom(ClusterHealthStatus.GREEN, ClusterHealthStatus.YELLOW));
SecurityIndexManager.State currentState = dummyState(null);
service.onSecurityIndexStateChange(previousState, currentState);
assertEquals(++expectedInvalidation, service.getNumInvalidation());
// doesn't exist to exists
previousState = dummyState(null);
currentState = dummyState(randomFrom(ClusterHealthStatus.GREEN, ClusterHealthStatus.YELLOW));
service.onSecurityIndexStateChange(previousState, currentState);
assertEquals(++expectedInvalidation, service.getNumInvalidation());
// green or yellow to red
previousState = dummyState(randomFrom(ClusterHealthStatus.GREEN, ClusterHealthStatus.YELLOW));
currentState = dummyState(ClusterHealthStatus.RED);
service.onSecurityIndexStateChange(previousState, currentState);
assertEquals(expectedInvalidation, service.getNumInvalidation());
// red to non red
previousState = dummyState(ClusterHealthStatus.RED);
currentState = dummyState(randomFrom(ClusterHealthStatus.GREEN, ClusterHealthStatus.YELLOW));
service.onSecurityIndexStateChange(previousState, currentState);
assertEquals(++expectedInvalidation, service.getNumInvalidation());
// green to yellow or yellow to green
previousState = dummyState(randomFrom(ClusterHealthStatus.GREEN, ClusterHealthStatus.YELLOW));
currentState = dummyState(previousState.indexStatus == ClusterHealthStatus.GREEN ?
ClusterHealthStatus.YELLOW : ClusterHealthStatus.GREEN);
service.onSecurityIndexStateChange(previousState, currentState);
assertEquals(expectedInvalidation, service.getNumInvalidation());
}
public void testAuthenticateSmartRealmOrderingDisabled() {
final Settings settings = Settings.builder()
.put(AuthenticationService.SUCCESS_AUTH_CACHE_ENABLED.getKey(), false)
.build();
service = new AuthenticationService(settings, realms, auditTrail,
new DefaultAuthenticationFailureHandler(Collections.emptyMap()), threadPool, new AnonymousUser(Settings.EMPTY),
tokenService);
User user = new User("_username", "r1");
when(firstRealm.supports(token)).thenReturn(true);
mockAuthenticate(firstRealm, token, null);
when(secondRealm.supports(token)).thenReturn(true);
mockAuthenticate(secondRealm, token, user);
when(secondRealm.token(threadContext)).thenReturn(token);
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
final AtomicBoolean completed = new AtomicBoolean(false);
service.authenticate("_action", message, (User)null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
assertTrue(completed.get());
completed.set(false);
service.authenticate("_action", message, (User)null, ActionListener.wrap(result -> {
assertThat(result, notNullValue());
assertThat(result.getUser(), is(user));
assertThat(result.getLookedUpBy(), is(nullValue()));
assertThat(result.getAuthenticatedBy(), is(notNullValue())); // TODO implement equals
assertThreadContextContainsAuthentication(result);
setCompletedToTrue(completed);
}, this::logAndFail));
verify(auditTrail, times(2)).authenticationFailed(reqId, firstRealm.name(), token, "_action", message);
verify(auditTrail, times(2)).authenticationSuccess(reqId, secondRealm.name(), user, "_action", message);
verify(firstRealm, times(3)).name(); // used above one time
verify(secondRealm, times(3)).name(); // used above one time
verify(secondRealm, times(2)).type(); // used to create realm ref
verify(firstRealm, times(2)).token(threadContext);
verify(secondRealm, times(2)).token(threadContext);
verify(firstRealm, times(2)).supports(token);
verify(secondRealm, times(2)).supports(token);
verify(firstRealm, times(2)).authenticate(eq(token), any(ActionListener.class));
verify(secondRealm, times(2)).authenticate(eq(token), any(ActionListener.class));
verifyNoMoreInteractions(auditTrail, firstRealm, secondRealm);
}
public void testAuthenticateFirstNotSupportingSecondSucceeds() throws Exception { public void testAuthenticateFirstNotSupportingSecondSucceeds() throws Exception {
User user = new User("_username", "r1"); User user = new User("_username", "r1");
when(firstRealm.supports(token)).thenReturn(false); when(firstRealm.supports(token)).thenReturn(false);
@ -614,6 +745,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmSupportsMethodThrowingException() throws Exception { public void testRealmSupportsMethodThrowingException() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenThrow(authenticationError("realm doesn't like supports")); when(secondRealm.supports(token)).thenThrow(authenticationError("realm doesn't like supports"));
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
@ -628,6 +760,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmSupportsMethodThrowingExceptionRest() throws Exception { public void testRealmSupportsMethodThrowingExceptionRest() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenThrow(authenticationError("realm doesn't like supports")); when(secondRealm.supports(token)).thenThrow(authenticationError("realm doesn't like supports"));
try { try {
@ -643,6 +776,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmAuthenticateTerminatingAuthenticationProcess() throws Exception { public void testRealmAuthenticateTerminatingAuthenticationProcess() throws Exception {
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
final AuthenticationToken token = mock(AuthenticationToken.class); final AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
final boolean terminateWithNoException = rarely(); final boolean terminateWithNoException = rarely();
@ -684,6 +818,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmAuthenticateThrowingException() throws Exception { public void testRealmAuthenticateThrowingException() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
doThrow(authenticationError("realm doesn't like authenticate")) doThrow(authenticationError("realm doesn't like authenticate"))
@ -700,6 +835,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmAuthenticateThrowingExceptionRest() throws Exception { public void testRealmAuthenticateThrowingExceptionRest() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
doThrow(authenticationError("realm doesn't like authenticate")) doThrow(authenticationError("realm doesn't like authenticate"))
@ -716,6 +852,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmLookupThrowingException() throws Exception { public void testRealmLookupThrowingException() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
@ -736,6 +873,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRealmLookupThrowingExceptionRest() throws Exception { public void testRealmLookupThrowingExceptionRest() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
@ -755,6 +893,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRunAsLookupSameRealm() throws Exception { public void testRunAsLookupSameRealm() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
@ -803,6 +942,7 @@ public class AuthenticationServiceTests extends ESTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testRunAsLookupDifferentRealm() throws Exception { public void testRunAsLookupDifferentRealm() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
@ -839,6 +979,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRunAsWithEmptyRunAsUsernameRest() throws Exception { public void testRunAsWithEmptyRunAsUsernameRest() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
User user = new User("lookup user", new String[]{"user"}); User user = new User("lookup user", new String[]{"user"});
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, ""); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
@ -857,6 +998,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testRunAsWithEmptyRunAsUsername() throws Exception { public void testRunAsWithEmptyRunAsUsername() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
User user = new User("lookup user", new String[]{"user"}); User user = new User("lookup user", new String[]{"user"});
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, ""); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "");
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
@ -876,6 +1018,7 @@ public class AuthenticationServiceTests extends ESTestCase {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
public void testAuthenticateTransportDisabledRunAsUser() throws Exception { public void testAuthenticateTransportDisabledRunAsUser() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
final String reqId = AuditUtil.getOrGenerateRequestId(threadContext); final String reqId = AuditUtil.getOrGenerateRequestId(threadContext);
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
@ -897,6 +1040,7 @@ public class AuthenticationServiceTests extends ESTestCase {
public void testAuthenticateRestDisabledRunAsUser() throws Exception { public void testAuthenticateRestDisabledRunAsUser() throws Exception {
AuthenticationToken token = mock(AuthenticationToken.class); AuthenticationToken token = mock(AuthenticationToken.class);
when(token.principal()).thenReturn(randomAlphaOfLength(5));
threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as"); threadContext.putHeader(AuthenticationServiceField.RUN_AS_USER_HEADER, "run_as");
when(secondRealm.token(threadContext)).thenReturn(token); when(secondRealm.token(threadContext)).thenReturn(token);
when(secondRealm.supports(token)).thenReturn(true); when(secondRealm.supports(token)).thenReturn(true);
@ -1117,4 +1261,8 @@ public class AuthenticationServiceTests extends ESTestCase {
private void setCompletedToTrue(AtomicBoolean completed) { private void setCompletedToTrue(AtomicBoolean completed) {
assertTrue(completed.compareAndSet(false, true)); assertTrue(completed.compareAndSet(false, true));
} }
private SecurityIndexManager.State dummyState(ClusterHealthStatus indexStatus) {
return new SecurityIndexManager.State(true, true, true, true, null, indexStatus);
}
} }