Refactor Token Service (#39808)

This refactoring is in the context of the work related to moving security
tokens to a new index. In that regard, the Token Service has to work with
token documents stored in any of the two indices, albeit only as a transient
situation. I reckoned the added complexity as unmanageable,
hence this refactoring.

This is incomplete, as it fails to address the goal of minimizing .security accesses,
but I have stopped because otherwise it would've become a full blown rewrite
(if not already). I will follow-up with more targeted PRs.

In addition to being a true refactoring, some 400 errors moved to 500. Furthermore,
more stringed validation of various return result, has been implemented, notably the
one of the token document creation.
This commit is contained in:
Albert Zaharovits 2019-03-21 15:55:56 +02:00 committed by GitHub
parent 244e6758cd
commit 2f80b7304f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 517 additions and 490 deletions

View File

@ -61,13 +61,13 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio
}
assert authentication != null : "authentication should never be null at this point";
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createUserToken(authentication, originatingAuthentication,
ActionListener.wrap(tuple -> {
tokenService.createOAuth2Tokens(authentication, originatingAuthentication,
tokenMeta, true, ActionListener.wrap(tuple -> {
final String tokenString = tokenService.getAccessTokenAsString(tuple.v1());
final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse(
new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn));
}, listener::onFailure), tokenMeta, true);
}, listener::onFailure));
}, e -> {
logger.debug(() -> new ParameterizedMessage("SamlToken [{}] could not be authenticated", saml), e);
listener.onFailure(e);

View File

@ -91,7 +91,7 @@ public final class TransportSamlInvalidateSessionAction
return;
}
tokenService.findActiveTokensForRealm(realm.name(), ActionListener.wrap(tokens -> {
tokenService.findActiveTokensForRealm(realm.name(), containsMetadata(tokenMetadata), ActionListener.wrap(tokens -> {
logger.debug("Found [{}] token pairs to invalidate for SAML metadata [{}]", tokens.size(), tokenMetadata);
if (tokens.isEmpty()) {
listener.onResponse(0);
@ -101,7 +101,7 @@ public final class TransportSamlInvalidateSessionAction
tokens.forEach(tuple -> invalidateTokenPair(tuple, groupedListener));
}
}, listener::onFailure
), containsMetadata(tokenMetadata));
));
}
private void invalidateTokenPair(Tuple<UserToken, String> tokenPair, ActionListener<TokensInvalidationResult> listener) {

View File

@ -28,7 +28,6 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
import org.opensaml.saml.saml2.core.LogoutRequest;
import java.io.IOException;
import java.util.Map;
/**
@ -73,7 +72,7 @@ public final class TransportSamlLogoutAction
));
}, listener::onFailure
));
} catch (IOException | ElasticsearchException e) {
} catch (ElasticsearchException e) {
logger.debug("Internal exception during SAML logout", e);
listener.onFailure(e);
}

View File

@ -21,7 +21,6 @@ import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken
import org.elasticsearch.xpack.security.authc.AuthenticationService;
import org.elasticsearch.xpack.security.authc.TokenService;
import java.io.IOException;
import java.util.Collections;
/**
@ -86,19 +85,15 @@ public final class TransportCreateTokenAction extends HandledTransportAction<Cre
}
private void createToken(CreateTokenRequest request, Authentication authentication, Authentication originatingAuth,
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
try {
tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response =
new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope, tuple.v2());
listener.onResponse(response);
}, listener::onFailure), Collections.emptyMap(), includeRefreshToken);
} catch (IOException e) {
listener.onFailure(e);
}
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
tokenService.createOAuth2Tokens(authentication, originatingAuth, Collections.emptyMap(), includeRefreshToken,
ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response = new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope,
tuple.v2());
listener.onResponse(response);
}, listener::onFailure));
}
static String getResponseScopeValue(String requestScope) {

View File

@ -16,6 +16,7 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import java.io.IOException;
import java.time.DateTimeException;
import java.time.Instant;
import java.util.Base64;
import java.util.Collections;
@ -140,17 +141,31 @@ public final class UserToken implements Writeable, ToXContentObject {
return builder.endObject();
}
static UserToken fromSourceMap(Map<String, Object> source) throws IOException {
static UserToken fromSourceMap(Map<String, Object> source) throws IllegalStateException, DateTimeException {
final String id = (String) source.get("id");
if (id == null) {
throw new IllegalStateException("user token source document does not have the \"id\" field");
}
final Long expirationEpochMilli = (Long) source.get("expiration_time");
if (expirationEpochMilli == null) {
throw new IllegalStateException("user token source document does not have the \"expiration_time\" field");
}
final Integer versionId = (Integer) source.get("version");
if (versionId == null) {
throw new IllegalStateException("user token source document does not have the \"version\" field");
}
final Map<String, Object> metadata = (Map<String, Object>) source.get("metadata");
final String authString = (String) source.get("authentication");
if (authString == null) {
throw new IllegalStateException("user token source document does not have the \"authentication\" field");
}
final Version version = Version.fromId(versionId);
try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) {
in.setVersion(version);
Authentication authentication = new Authentication(in);
return new UserToken(id, version, authentication, Instant.ofEpochMilli(expirationEpochMilli), metadata);
} catch (IOException e) {
throw new IllegalStateException("user token source document contains malformed \"authentication\" field", e);
}
}
}

View File

@ -356,7 +356,7 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
new RealmRef("native", NativeRealmSettings.TYPE, "node01"), null);
final Map<String, Object> metadata = samlRealm.createTokenMetadata(nameId, session);
final PlainActionFuture<Tuple<UserToken, String>> future = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, future, metadata, true);
tokenService.createOAuth2Tokens(authentication, authentication, metadata, true, future);
return future.actionGet();
}

View File

@ -239,7 +239,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
new SamlNameId(NameID.TRANSIENT, nameId, null, null, null), session);
final PlainActionFuture<Tuple<UserToken, String>> future = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true);
tokenService.createOAuth2Tokens(authentication, authentication, tokenMetaData, true, future);
final UserToken userToken = future.actionGet().v1();
mockGetTokenFromId(userToken, false, client);
final String tokenString = tokenService.getAccessTokenAsString(userToken);

View File

@ -25,8 +25,10 @@ import org.elasticsearch.action.update.UpdateAction;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.node.Node;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.ESTestCase;
@ -110,7 +112,8 @@ public class TransportCreateTokenActionTests extends ESTestCase {
doAnswer(invocationOnMock -> {
idxReqReference.set((IndexRequest) invocationOnMock.getArguments()[1]);
ActionListener<IndexResponse> responseActionListener = (ActionListener<IndexResponse>) invocationOnMock.getArguments()[2];
responseActionListener.onResponse(new IndexResponse());
responseActionListener.onResponse(new IndexResponse(new ShardId(".security", UUIDs.randomBase64UUID(), randomInt()), "_doc",
randomAlphaOfLength(4), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), true));
return null;
}).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class));

View File

@ -42,6 +42,7 @@ import org.elasticsearch.env.Environment;
import org.elasticsearch.env.TestEnvironment;
import org.elasticsearch.index.get.GetResult;
import org.elasticsearch.index.seqno.SequenceNumbers;
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.rest.RestRequest;
import org.elasticsearch.rest.RestStatus;
@ -195,7 +196,8 @@ public class AuthenticationServiceTests extends ESTestCase {
.thenReturn(new UpdateRequestBuilder(client, UpdateAction.INSTANCE));
doAnswer(invocationOnMock -> {
ActionListener<IndexResponse> responseActionListener = (ActionListener<IndexResponse>) invocationOnMock.getArguments()[2];
responseActionListener.onResponse(new IndexResponse());
responseActionListener.onResponse(new IndexResponse(new ShardId(".security", UUIDs.randomBase64UUID(), randomInt()), "_doc",
randomAlphaOfLength(4), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), true));
return null;
}).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class));
doAnswer(invocationOnMock -> {
@ -1107,7 +1109,7 @@ public class AuthenticationServiceTests extends ESTestCase {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(expected, originatingAuth, Collections.emptyMap(), true, tokenFuture);
}
String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1());
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
@ -1190,7 +1192,7 @@ public class AuthenticationServiceTests extends ESTestCase {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
try (ThreadContext.StoredContext ctx = threadContext.stashContext()) {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(expected, originatingAuth, Collections.emptyMap(), true, tokenFuture);
}
String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1());
mockGetTokenFromId(tokenFuture.get().v1(), true, client);

View File

@ -23,6 +23,7 @@ import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -114,7 +115,8 @@ public class TokenServiceTests extends ESTestCase {
.thenReturn(new UpdateRequestBuilder(client, UpdateAction.INSTANCE));
doAnswer(invocationOnMock -> {
ActionListener<IndexResponse> responseActionListener = (ActionListener<IndexResponse>) invocationOnMock.getArguments()[2];
responseActionListener.onResponse(new IndexResponse());
responseActionListener.onResponse(new IndexResponse(new ShardId(".security", UUIDs.randomBase64UUID(), randomInt()), "_doc",
randomAlphaOfLength(4), randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), true));
return null;
}).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class));
@ -152,7 +154,7 @@ public class TokenServiceTests extends ESTestCase {
TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, false);
@ -199,7 +201,7 @@ public class TokenServiceTests extends ESTestCase {
TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, false);
@ -224,7 +226,7 @@ public class TokenServiceTests extends ESTestCase {
}
PlainActionFuture<Tuple<UserToken, String>> newTokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, newTokenFuture);
final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken);
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));
@ -254,12 +256,11 @@ public class TokenServiceTests extends ESTestCase {
for (int i = 0; i < numRotations; i++) {
rotateKeys(tokenService);
}
TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex,
clusterService);
TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
otherTokenService.refreshMetaData(tokenService.getTokenMetaData());
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, false);
@ -290,7 +291,7 @@ public class TokenServiceTests extends ESTestCase {
TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, false);
@ -321,7 +322,7 @@ public class TokenServiceTests extends ESTestCase {
}
PlainActionFuture<Tuple<UserToken, String>> newTokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, newTokenFuture);
final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken);
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));
@ -352,7 +353,7 @@ public class TokenServiceTests extends ESTestCase {
TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, false);
@ -383,7 +384,7 @@ public class TokenServiceTests extends ESTestCase {
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
UserToken token = tokenFuture.get().v1();
assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue());
@ -397,7 +398,7 @@ public class TokenServiceTests extends ESTestCase {
new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token, true);
@ -451,7 +452,7 @@ public class TokenServiceTests extends ESTestCase {
TokenService tokenService = new TokenService(tokenServiceEnabledSettings, clock, client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
mockGetTokenFromId(token, false);
authentication = token.getAuthentication();
@ -503,7 +504,7 @@ public class TokenServiceTests extends ESTestCase {
.build(),
Clock.systemUTC(), client, securityIndex, clusterService);
IllegalStateException e = expectThrows(IllegalStateException.class,
() -> tokenService.createUserToken(null, null, null, null, true));
() -> tokenService.createOAuth2Tokens(null, null, null, true, null));
assertEquals("tokens are not enabled", e.getMessage());
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -561,7 +562,7 @@ public class TokenServiceTests extends ESTestCase {
new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
tokenService.createOAuth2Tokens(authentication, authentication, Collections.emptyMap(), true, tokenFuture);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
//mockGetTokenFromId(token, false);