From 7ed9d52824f466ddfecedc85e73d5bc8827a8fa7 Mon Sep 17 00:00:00 2001
From: Ioannis Kakavas
Date: Tue, 5 Mar 2019 14:55:59 +0200
Subject: [PATCH] Support concurrent refresh of refresh tokens (#39647)

This is a backport of #39631

Co-authored-by: Jay Modi jaymode@users.noreply.github.com

This change adds support for the concurrent refresh of access tokens as described in #36872. In short, it allows subsequent client requests to refresh the same token that arrive within a predefined window of 60 seconds to be handled as duplicates of the original one, and thus to receive the same response with the same newly issued access token and refresh token.

In order to support that, two new fields are added to the token document. One contains the instant (in epochMillis) at which a given refresh token was refreshed, and the other contains a pointer to the token document that stores the new refresh token and access token created by the original refresh.

A side effect of this change, which was however also an intended enhancement for the token service, is that we needed to stop encrypting the string representation of the UserToken while serializing. (This was necessary because we correctly used a new IV every time we encrypted a token during serialization, so subsequent serializations of the same exact UserToken would produce different access token strings.)

This change also handles the serialization/deserialization BWC logic:

- In mixed clusters we keep creating tokens in the old format and consume only old format tokens
- In upgraded clusters, we start creating tokens in the new format but still remain able to consume old format tokens (that could have been created during the rolling upgrade and are still valid)
- When reading/writing TokensInvalidationResult objects, we take into consideration that pre 7.1.0 these contained an integer field that carried the attempt count

Resolves #36872
---
 .../support/TokensInvalidationResult.java | 19 +-
 .../resources/security-index-template.json | 7 +
 .../token/InvalidateTokenResponseTests.java | 9 +-
 .../saml/TransportSamlAuthenticateAction.java | 2 +-
 .../token/TransportCreateTokenAction.java | 2 +-
 .../token/TransportRefreshTokenAction.java | 2 +-
 .../xpack/security/authc/TokenService.java | 789 ++++++++++++------
 .../saml/TransportSamlLogoutActionTests.java | 2 +-
 .../authc/AuthenticationServiceTests.java | 4 +-
 .../security/authc/TokenAuthIntegTests.java | 103 ++-
 .../security/authc/TokenServiceTests.java | 68 +-
 .../TokensInvalidationResultTests.java | 8 +-
 12 files changed, 710 insertions(+), 305 deletions(-)

diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java index f6e7965d963..f9985dfba7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.security.authc.support; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -32,10 +33,9 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { private final List invalidatedTokens; private
final List previouslyInvalidatedTokens; private final List errors; - private final int attemptCount; public TokensInvalidationResult(List invalidatedTokens, List previouslyInvalidatedTokens, - @Nullable List errors, int attemptCount) { + @Nullable List errors) { Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); this.invalidatedTokens = invalidatedTokens; Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); @@ -45,18 +45,19 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { } else { this.errors = Collections.emptyList(); } - this.attemptCount = attemptCount; } public TokensInvalidationResult(StreamInput in) throws IOException { this.invalidatedTokens = in.readStringList(); this.previouslyInvalidatedTokens = in.readStringList(); this.errors = in.readList(StreamInput::readException); - this.attemptCount = in.readVInt(); + if (in.getVersion().before(Version.V_7_1_0)) { + in.readVInt(); + } } public static TokensInvalidationResult emptyResult() { - return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); + return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); } @@ -72,10 +73,6 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { return errors; } - public int getAttemptCount() { - return attemptCount; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject() @@ -100,6 +97,8 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { out.writeStringCollection(invalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens); out.writeCollection(errors, StreamOutput::writeException); - out.writeVInt(attemptCount); + if (out.getVersion().before(Version.V_7_1_0)) { + out.writeVInt(5); + } } } diff --git a/x-pack/plugin/core/src/main/resources/security-index-template.json b/x-pack/plugin/core/src/main/resources/security-index-template.json index 183ffff4ea5..e938464ac6f 100644 --- a/x-pack/plugin/core/src/main/resources/security-index-template.json +++ b/x-pack/plugin/core/src/main/resources/security-index-template.json @@ -199,6 +199,13 @@ "refreshed" : { "type" : "boolean" }, + "refresh_time": { + "type": "date", + "format": "epoch_millis" + }, + "superseded_by": { + "type": "keyword" + }, "invalidated" : { "type" : "boolean" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java index c9c2e470644..bbfba920e38 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java @@ -29,8 +29,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), - randomIntBetween(0, 5)); + new 
ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); InvalidateTokenResponse response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -47,8 +46,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { } result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), - Arrays.asList(generateRandomStringArray(20, 15, false)), - Collections.emptyList(), randomIntBetween(0, 5)); + Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList()); response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -68,8 +66,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), - randomIntBetween(0, 5)); + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); InvalidateTokenResponse response = new InvalidateTokenResponse(result); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java index dee12f4a6bd..0e5acf5394f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java @@ -63,7 +63,7 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio final Map tokenMeta = (Map) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); tokenService.createUserToken(authentication, originatingAuthentication, ActionListener.wrap(tuple -> { - final String tokenString = tokenService.getUserTokenString(tuple.v1()); + final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); final TimeValue expiresIn = tokenService.getExpirationDelay(); listener.onResponse( new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java index 5d5442803e3..75c3ee9df42 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java @@ -89,7 +89,7 @@ public final class TransportCreateTokenAction extends HandledTransportAction listener) { try { tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String tokenStr = 
tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java index 0eac8d71fb2..71aeb64bc42 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java @@ -31,7 +31,7 @@ public class TransportRefreshTokenAction extends HandledTransportAction listener) { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 36144899d28..b1db11b48f0 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -19,6 +19,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest.OpType; import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -28,6 +29,7 @@ import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.update.UpdateRequest; @@ -67,6 +69,7 @@ import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.XPackField; @@ -113,12 +116,12 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Predicate; @@ -129,6 +132,7 @@ import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK import static 
org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.threadpool.ThreadPool.Names.GENERIC; /** * Service responsible for the creation, validation, and other management of {@link UserToken} @@ -155,6 +159,7 @@ public final class TokenService { private static final String MALFORMED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY + "\", error=\"invalid_token\", error_description=\"The access token is malformed\""; private static final String TYPE = "doc"; + private static final BackoffPolicy DEFAULT_BACKOFF = BackoffPolicy.exponentialBackoff(); public static final String THREAD_POOL_NAME = XPackField.SECURITY + "-token-key"; public static final Setting TOKEN_EXPIRATION = Setting.timeSetting("xpack.security.authc.token.timeout", @@ -167,8 +172,7 @@ public final class TokenService { private static final String TOKEN_DOC_TYPE = "token"; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; - private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); - private static final int MAX_RETRY_ATTEMPTS = 5; + static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); private static final Logger logger = LogManager.getLogger(TokenService.class); private final SecureRandom secureRandom = new SecureRandom(); @@ -221,12 +225,22 @@ public final class TokenService { } /** - * Create a token based on the provided authentication and metadata. + * Creates a token based on the provided authentication and metadata with an auto-generated token id. * The created token will be stored in the security index. */ public void createUserToken(Authentication authentication, Authentication originatingClientAuth, ActionListener> listener, Map metadata, boolean includeRefreshToken) throws IOException { + createUserToken(UUIDs.randomBase64UUID(), authentication, originatingClientAuth, listener, metadata, includeRefreshToken); + } + + /** + * Create a token based on the provided authentication and metadata with the given token id. + * The created token will be stored in the security index. + */ + private void createUserToken(String userTokenId, Authentication authentication, Authentication originatingClientAuth, + ActionListener> listener, Map metadata, + boolean includeRefreshToken) throws IOException { ensureEnabled(); if (authentication == null) { listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); @@ -239,7 +253,7 @@ public final class TokenService { final Version version = clusterService.state().nodes().getMinNodeVersion(); final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), authentication.getLookedUpBy(), version, AuthenticationType.TOKEN, authentication.getMetadata()); - final UserToken userToken = new UserToken(version, tokenAuth, expiration, metadata); + final UserToken userToken = new UserToken(userTokenId, version, tokenAuth, expiration, metadata); final String refreshToken = includeRefreshToken ? 
UUIDs.randomBase64UUID() : null; try (XContentBuilder builder = XContentFactory.jsonBuilder()) { @@ -280,9 +294,33 @@ public final class TokenService { } } + /** + * Reconstructs the {@link UserToken} from the existing {@code userTokenSource} and call the listener with the {@link UserToken} and the + * refresh token string + */ + private void reIssueTokens(Map userTokenSource, + String refreshToken, ActionListener> listener) { + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + final String id = (String) userTokenSource.get("id"); + final Long expiration = (Long) userTokenSource.get("expiration_time"); + + Version authVersion = Version.fromId(version); + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + Authentication authentication = new Authentication(in); + UserToken userToken = new UserToken(id, authVersion, authentication, Instant.ofEpochMilli(expiration), metadata); + listener.onResponse(new Tuple<>(userToken, refreshToken)); + } catch (IOException e) { + logger.error("Unable to decode existing user token", e); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } + } + /** * Looks in the context to see if the request provided a header with a user token and if so the - * token is validated, which includes authenticated decryption and verification that the token + * token is validated, which might include authenticated decryption and verification that the token * has not been revoked or is expired. */ void getAndValidateToken(ThreadContext ctx, ActionListener listener) { @@ -329,23 +367,78 @@ public final class TokenService { } /** - * Asynchronously decodes the string representation of a {@link UserToken}. The process for - * this is asynchronous as we may need to compute a key, which can be computationally expensive - * so this should not block the current thread, which is typically a network thread. A second - * reason for being asynchronous is that we can restrain the amount of resources consumed by - * the key computation to a single thread. 
+ * Gets the UserToken with the given id by fetching the corresponding token document + */ + void getUserTokenFromId(String userTokenId, ActionListener listener) { + if (securityIndex.isAvailable() == false) { + logger.warn("failed to get token [{}] since index is not available", userTokenId); + listener.onResponse(null); + } else { + securityIndex.checkIndexVersionThenExecute( + ex -> listener.onFailure(traceLog("prepare security index", userTokenId, ex)), + () -> { + final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, + getTokenDocumentId(userTokenId)).request(); + Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", userTokenId, ex)); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, + ActionListener.wrap(response -> { + if (response.isExists()) { + Map accessTokenSource = + (Map) response.getSource().get("access_token"); + if (accessTokenSource == null) { + onFailure.accept(new IllegalStateException( + "token document is missing the access_token field")); + } else if (accessTokenSource.containsKey("user_token") == false) { + onFailure.accept(new IllegalStateException( + "token document is missing the user_token field")); + } else { + Map userTokenSource = + (Map) accessTokenSource.get("user_token"); + listener.onResponse(UserToken.fromSourceMap(userTokenSource)); + } + } else { + onFailure.accept( + new IllegalStateException("token document is missing and must be present")); + } + }, e -> { + // if the index or the shard is not there / available we assume that + // the token is not valid + if (isShardNotAvailableException(e)) { + logger.warn("failed to get token [{}] since index is not available", userTokenId); + listener.onResponse(null); + } else { + logger.error(new ParameterizedMessage("failed to get token [{}]", userTokenId), e); + listener.onFailure(e); + } + }), client::get); + }); + } + } + + /* + * If needed, for tokens that were created in a pre 7.1.0 cluster, it asynchronously decodes the token to get the token document Id. + * The process for this is asynchronous as we may need to compute a key, which can be computationally expensive + * so this should not block the current thread, which is typically a network thread. A second reason for being asynchronous is that + * we can restrain the amount of resources consumed by the key computation to a single thread. + * For tokens created in a 7.1.0 or later cluster, the token is just the token document Id so this is used directly without decryption */ void decodeToken(String token, ActionListener listener) throws IOException { // We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key!
byte[] bytes = token.getBytes(StandardCharsets.UTF_8); StreamInput in = new InputStreamStreamInput(Base64.getDecoder().wrap(new ByteArrayInputStream(bytes)), bytes.length); - if (in.available() < MINIMUM_BASE64_BYTES) { - logger.debug("invalid token"); - listener.onResponse(null); + final Version version = Version.readVersion(in); + if (version.onOrAfter(Version.V_7_1_0)) { + // The token was created in a > 7.1.0 cluster so it contains the tokenId as a String + String usedTokenId = in.readString(); + getUserTokenFromId(usedTokenId, listener); } else { - // the token exists and the value is at least as long as we'd expect - final Version version = Version.readVersion(in); + // The token was created in a < 7.1.0 cluster so we need to decrypt it to get the tokenId in.setVersion(version); + if (in.available() < MINIMUM_BASE64_BYTES) { + logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BASE64_BYTES); + listener.onResponse(null); + return; + } final BytesKey decodedSalt = new BytesKey(in.readByteArray()); final BytesKey passphraseHash = new BytesKey(in.readByteArray()); KeyAndCache keyAndCache = keyCache.get(passphraseHash); @@ -354,51 +447,8 @@ public final class TokenService { try { final byte[] iv = in.readByteArray(); final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt); - decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> { - if (securityIndex.isAvailable() == false) { - logger.warn("failed to get token [{}] since index is not available", tokenId); - listener.onResponse(null); - } else { - securityIndex.checkIndexVersionThenExecute( - ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)), - () -> { - final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, - getTokenDocumentId(tokenId)).request(); - Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex)); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists()) { - Map accessTokenSource = - (Map) response.getSource().get("access_token"); - if (accessTokenSource == null) { - onFailure.accept(new IllegalStateException( - "token document is missing the access_token field")); - } else if (accessTokenSource.containsKey("user_token") == false) { - onFailure.accept(new IllegalStateException( - "token document is missing the user_token field")); - } else { - Map userTokenSource = - (Map) accessTokenSource.get("user_token"); - listener.onResponse(UserToken.fromSourceMap(userTokenSource)); - } - } else { - onFailure.accept( - new IllegalStateException("token document is missing and must be present")); - } - }, e -> { - // if the index or the shard is not there / available we assume that - // the token is not valid - if (isShardNotAvailableException(e)) { - logger.warn("failed to get token [{}] since index is not available", tokenId); - listener.onResponse(null); - } else { - logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e); - listener.onFailure(e); - } - }), client::get); - }); - } - }, listener::onFailure)); + decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> getUserTokenFromId(tokenId, listener), + listener::onFailure)); } catch (GeneralSecurityException e) { // could happen with a token that is not ours logger.warn("invalid token", e); @@ -442,8 +492,8 @@ public final class TokenService { /** * This method performs the steps necessary to invalidate a token so 
that it may no longer be - used. The process of invalidation involves performing an update to - the token document and setting the invalidated field to true + used. The process of invalidation involves performing an update to the token document and setting + the invalidated field to true */ public void invalidateAccessToken(String tokenString, ActionListener listener) { ensureEnabled(); @@ -452,12 +502,13 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("token must be provided")); } else { maybeStartTokenRemover(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); try { decodeToken(tokenString, ActionListener.wrap(userToken -> { if (userToken == null) { listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException())); } else { - indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), + indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null); } }, listener::onFailure)); @@ -480,12 +531,14 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("token must be provided")); } else { maybeStartTokenRemover(); - indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); + indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null); } } /** - * This method performs the steps necessary to invalidate a refresh token so that it may no longer be used. + * This method invalidates a refresh token so that it may no longer be used. Invalidation involves performing an update to the token + * document and setting the refresh_token.invalidated field to true * * @param refreshToken The string representation of the refresh token * @param listener the listener to notify upon completion @@ -497,16 +550,17 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("refresh token must be provided")); } else { maybeStartTokenRemover(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(tuple -> { - final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId()); - indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null); - }, listener::onFailure), new AtomicInteger(0)); + ActionListener.wrap(searchResponse -> { + final String docId = getTokenIdFromDocumentId(searchResponse.getHits().getAt(0).getId()); + indexInvalidation(Collections.singletonList(docId), listener, backoff, "refresh_token", null); + }, listener::onFailure), backoff); } } /** - * Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given + * Invalidates all access tokens and all refresh tokens of a given {@code realmName} and/or of a given * {@code username} so that they may no longer be used * * @param realmName the realm of which the tokens should be invalidated * @param username the username for which the tokens should be invalidated @@ -557,32 +611,30 @@ public final class TokenService { maybeStartTokenRemover(); // Invalidate the refresh tokens first so that they cannot be used to get new // access tokens while we invalidate the access tokens we currently know about + final Iterator backoff = DEFAULT_BACKOFF.iterator(); indexInvalidation(accessTokenIds, ActionListener.wrap(result -> - indexInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()), - "access_token", result),
- listener::onFailure), new AtomicInteger(0), "refresh_token", null); + indexInvalidation(accessTokenIds, listener, backoff, "access_token", result), + listener::onFailure), backoff, "refresh_token", null); } /** - * Performs the actual invalidation of a collection of tokens + * Performs the actual invalidation of a collection of tokens. In case of recoverable errors ( see + * {@link TransportActions#isShardNotAvailableException} ) the UpdateRequests to mark the tokens as invalidated are retried using + * an exponential backoff policy. * * @param tokenIds the tokens to invalidate * @param listener the listener to notify upon completion - * @param attemptCount the number of attempts to invalidate that have already been tried + * @param backoff the amount of time to delay between attempts * @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on * what type of tokens should be invalidated * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating * tokens up to the point of the retry. This result is added to the result of the current attempt */ private void indexInvalidation(Collection tokenIds, ActionListener listener, - AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { + Iterator backoff, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { if (tokenIds.isEmpty()) { logger.warn("No [{}] tokens provided for invalidation", srcPrefix); listener.onFailure(invalidGrantException("No tokens provided for invalidation")); - } else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(), - attemptCount.get()); - listener.onFailure(invalidGrantException("failed to invalidate tokens")); } else { BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); for (String tokenId : tokenIds) { @@ -627,20 +679,30 @@ public final class TokenService { } } if (retryTokenDocIds.isEmpty() == false) { - TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, - failedRequestResponses, attemptCount.get()); - attemptCount.incrementAndGet(); - indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult); + if (backoff.hasNext()) { + logger.debug("failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too", + retryTokenDocIds.size(), tokenIds.size()); + TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses); + client.threadPool().schedule( + () -> indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult), + backoff.next(), GENERIC); + } else { + logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", + retryTokenDocIds.size(), tokenIds.size()); + } + } else { + TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses); + listener.onResponse(result); } - TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, - failedRequestResponses, attemptCount.get()); - listener.onResponse(result); }, e -> { Throwable cause = ExceptionsHelper.unwrapCause(e); traceLog("invalidate tokens", cause); - if (isShardNotAvailableException(cause)) { - attemptCount.incrementAndGet(); - indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, 
previousResult); + if (isShardNotAvailableException(cause) && backoff.hasNext()) { + logger.debug("failed to invalidate tokens, retrying "); + client.threadPool().schedule( + () -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult), backoff.next(), GENERIC); } else { listener.onFailure(e); } @@ -649,142 +711,272 @@ public final class TokenService { } /** - * Uses the refresh token to refresh its associated token and returns the new token with an - * updated expiration date to the listener + * Called by the transport action in order to start the process of refreshing a token. */ public void refreshToken(String refreshToken, ActionListener> listener) { ensureEnabled(); + final Instant refreshRequested = clock.instant(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(tuple -> { + ActionListener.wrap(searchResponse -> { final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); - final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); - innerRefresh(tokenDocId, clientAuth, listener, tuple.v2()); + final SearchHit tokenDocHit = searchResponse.getHits().getHits()[0]; + final String tokenDocId = tokenDocHit.getId(); + innerRefresh(tokenDocId, tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), tokenDocHit.getPrimaryTerm(), clientAuth, + listener, backoff, refreshRequested); }, listener::onFailure), - new AtomicInteger(0)); + backoff); } - private void findTokenFromRefreshToken(String refreshToken, ActionListener> listener, - AtomicInteger attemptCount) { - if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to find token for refresh token [{}] after [{}] attempts", refreshToken, attemptCount.get()); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else { - SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setQuery(QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) - .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) - .setVersion(true) - .request(); + /** + * Performs an asynchronous search request for the token document that contains the {@code refreshToken} and calls the listener with the + * {@link SearchResponse}. In case of recoverable errors the SearchRequest is retried using an exponential backoff policy. 
+ */ + private void findTokenFromRefreshToken(String refreshToken, ActionListener listener, + Iterator backoff) { + SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setQuery(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) + .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) + .seqNoAndPrimaryTerm(true) + .request(); - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); - if (frozenSecurityIndex.indexExists() == false) { - logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else if (frozenSecurityIndex.isAvailable() == false) { - logger.debug("security index is not available to find token from refresh token, retrying"); - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); - } else { - Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); - securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, - ActionListener.wrap(searchResponse -> { - if (searchResponse.isTimedOut()) { - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); - } else if (searchResponse.getHits().getHits().length < 1) { - logger.info("could not find token document with refresh_token [{}]", refreshToken); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } else if (searchResponse.getHits().getHits().length > 1) { - onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); + final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); + if (frozenSecurityIndex.indexExists() == false) { + logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else if (frozenSecurityIndex.isAvailable() == false) { + logger.debug("security index is not available to find token from refresh token, retrying"); + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); + } else { + Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); + securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, + ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + if (backoff.hasNext()) { + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); } else { - listener.onResponse(new Tuple<>(searchResponse, attemptCount)); + logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else if (searchResponse.getHits().getHits().length < 1) { + logger.warn("could not find token document with refresh_token [{}]", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } else if (searchResponse.getHits().getHits().length > 1) { + 
onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); + } else { + listener.onResponse(searchResponse); + } + }, e -> { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.debug("failed to find token for refresh token [{}], retrying", refreshToken); + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); + } else { + logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + onFailure.accept(e); + } + }), + client::search)); + } + } + + /** + * Performs the actual refresh of the token with retries in case of certain exceptions that may be recoverable. The + * refresh involves two steps: + * First, we check if the token document is still valid for refresh ({@link TokenService#checkTokenDocForRefresh(Map, Authentication)} + * Then, in the case that the token has been refreshed within the previous 30 seconds (see + * {@link TokenService#checkLenientlyIfTokenAlreadyRefreshed(Map, Authentication)}), we do not create a new token document + * but instead retrieve the one that was created by the original refresh and return a user token and + * refresh token based on that ( see {@link TokenService#reIssueTokens(Map, String, ActionListener)} ). + * Otherwise this token document gets its refresh_token marked as refreshed, while also storing the Instant when it was + * refreshed along with a pointer to the new token document that holds the refresh_token that supersedes this one. The new + * document that contains the new access token and refresh token is created and finally the new access token and refresh token are + * returned to the listener. 
+ */ + private void innerRefresh(String tokenDocId, Map source, long seqNo, long primaryTerm, Authentication clientAuth, + ActionListener> listener, Iterator backoff, Instant refreshRequested) { + logger.debug("Attempting to refresh token [{}]", tokenDocId); + Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); + final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); + if (invalidSource.isPresent()) { + onFailure.accept(invalidSource.get()); + } else { + if (eligibleForMultiRefresh(source, refreshRequested)) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final String supersedingTokenDocId = (String) refreshTokenSrc.get("superseded_by"); + logger.debug("Token document [{}] was recently refreshed, attempting to reuse [{}] for returning an " + + "access token and refresh token", tokenDocId, supersedingTokenDocId); + final ActionListener getSupersedingListener = new ActionListener() { + @Override + public void onResponse(GetResponse response) { + if (response.isExists()) { + logger.debug("Found superseding token document [{}] ", supersedingTokenDocId); + final Map supersedingTokenSource = response.getSource(); + final Map supersedingUserTokenSource = (Map) + ((Map) supersedingTokenSource.get("access_token")).get("user_token"); + final Map supersedingRefreshTokenSrc = + (Map) supersedingTokenSource.get("refresh_token"); + final String supersedingRefreshTokenValue = (String) supersedingRefreshTokenSrc.get("token"); + reIssueTokens(supersedingUserTokenSource, supersedingRefreshTokenValue, listener); + } else if (backoff.hasNext()) { + // We retry this since the creation of the superseding token document might already be in flight but not + // yet completed, triggered by a refresh request that came a few milliseconds ago + logger.info("could not find superseding token document [{}] for token document [{}], retrying", + supersedingTokenDocId, tokenDocId); + client.threadPool().schedule(() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not find superseding token document [{}] for token document [{}] after all retries", + supersedingTokenDocId, tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + + @Override + public void onFailure(Exception e) { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId); + client.threadPool().schedule( + () -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not find token document [{}] for refresh after all retries", supersedingTokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + logger.warn("could not find superseding token document [{}] for refresh", supersedingTokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + }; + getTokenDocAsync(supersedingTokenDocId, getSupersedingListener); + } else { + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + Version authVersion = Version.fromId(version); + Authentication authentication; + try (StreamInput in 
= StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + authentication = new Authentication(in); + } catch (IOException e) { + logger.error("failed to decode the authentication stored with token document [{}]", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + return; + } + final String newUserTokenId = UUIDs.randomBase64UUID(); + final Instant refreshTime = clock.instant(); + Map updateMap = new HashMap<>(); + updateMap.put("refreshed", true); + updateMap.put("refresh_time", refreshTime.toEpochMilli()); + updateMap.put("superseded_by", getTokenDocumentId(newUserTokenId)); + UpdateRequestBuilder updateRequest = + client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) + .setDoc("refresh_token", updateMap) + .setFetchSource(true) + .setRefreshPolicy(RefreshPolicy.IMMEDIATE); + assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number"; + updateRequest.setIfSeqNo(seqNo); + assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term"; + updateRequest.setIfPrimaryTerm(primaryTerm); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), + ActionListener.wrap( + updateResponse -> { + if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + logger.debug("updated the original token document to {}", updateResponse.getGetResult().sourceAsMap()); + createUserToken(newUserTokenId, authentication, clientAuth, listener, metadata, true); + } else if (backoff.hasNext()) { + logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying", + tokenDocId, updateResponse.getResult()); + client.threadPool().schedule( + () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, + refreshRequested), + backoff.next(), GENERIC); + } else { + logger.info("failed to update the original token document [{}] after all retries, " + + "the update result was [{}]. ", tokenDocId, updateResponse.getResult()); + listener.onFailure(invalidGrantException("could not refresh the requested token")); } }, e -> { - if (isShardNotAvailableException(e)) { - logger.debug("failed to search for token document, retrying", e); - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof VersionConflictEngineException) { + //The document has been updated by another thread, get it again. 
+ if (backoff.hasNext()) { + logger.debug("version conflict while updating document [{}], attempting to get it again", + tokenDocId); + final ActionListener getListener = new ActionListener() { + @Override + public void onResponse(GetResponse response) { + if (response.isExists()) { + innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(), + response.getPrimaryTerm(), clientAuth, listener, backoff, refreshRequested); + } else { + logger.warn("could not find token document [{}] for refresh", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + + @Override + public void onFailure(Exception e) { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.info("could not get token document [{}] for refresh, " + + "retrying", tokenDocId); + client.threadPool().schedule( + () -> getTokenDocAsync(tokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not get token document [{}] for refresh after all retries", + tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + onFailure.accept(e); + } + } + }; + getTokenDocAsync(tokenDocId, getListener); + } else { + logger.warn("version conflict while updating document [{}], no retries left", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.debug("failed to update the original token document [{}], retrying", tokenDocId); + client.threadPool().schedule( + () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, + refreshRequested), + backoff.next(), GENERIC); + } else { + logger.warn("failed to update the original token document [{}], after all retries", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } } else { onFailure.accept(e); } }), - client::search)); + client::update); } } } - /** - * Performs the actual refresh of the token with retries in case of certain exceptions that - * may be recoverable. The refresh involves retrieval of the token document and then - * updating the token document to indicate that the document has been refreshed. 
- */ - private void innerRefresh(String tokenDocId, Authentication clientAuth, ActionListener> listener, - AtomicInteger attemptCount) { - if (attemptCount.getAndIncrement() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to refresh token for doc [{}] after [{}] attempts", tokenDocId, attemptCount.get()); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else { - Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); - GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists()) { - final Map source = response.getSource(); - final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); - - if (invalidSource.isPresent()) { - onFailure.accept(invalidSource.get()); - } else { - final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); - final String authString = (String) userTokenSource.get("authentication"); - final Integer version = (Integer) userTokenSource.get("version"); - final Map metadata = (Map) userTokenSource.get("metadata"); - - Version authVersion = Version.fromId(version); - try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { - in.setVersion(authVersion); - Authentication authentication = new Authentication(in); - UpdateRequestBuilder updateRequest = - client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) - .setDoc("refresh_token", Collections.singletonMap("refreshed", true)) - .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); - updateRequest.setIfSeqNo(response.getSeqNo()); - updateRequest.setIfPrimaryTerm(response.getPrimaryTerm()); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), - ActionListener.wrap( - updateResponse -> createUserToken(authentication, clientAuth, listener, metadata, true), - e -> { - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof VersionConflictEngineException || - isShardNotAvailableException(e)) { - innerRefresh(tokenDocId, clientAuth, - listener, attemptCount); - } else { - onFailure.accept(e); - } - }), - client::update); - } - } - } else { - logger.info("could not find token document [{}] for refresh", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - }, e -> { - if (isShardNotAvailableException(e)) { - innerRefresh(tokenDocId, clientAuth, listener, attemptCount); - } else { - listener.onFailure(e); - } - }), client::get); - } + private void getTokenDocAsync(String tokenDocId, ActionListener listener) { + GetRequest getRequest = + client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get); } /** * Performs checks on the retrieved source and returns an {@link Optional} with the exception - * if there is an issue + * if there is an issue that makes the retrieved token unsuitable to be refreshed */ private Optional checkTokenDocForRefresh(Map source, Authentication clientAuth) { final Map refreshTokenSrc = (Map) source.get("refresh_token"); @@ -805,8 +997,6 @@ public final class TokenService { return Optional.of(invalidGrantException("token document is missing invalidated value")); } else if 
(creationEpochMilli == null) { return Optional.of(invalidGrantException("token document is missing creation time value")); - } else if (refreshed) { - return Optional.of(invalidGrantException("token has already been refreshed")); } else if (invalidated) { return Optional.of(invalidGrantException("token has been invalidated")); } else if (clock.instant().isAfter(creationTime.plus(24L, ChronoUnit.HOURS))) { @@ -820,7 +1010,7 @@ public final class TokenService { } else if (userTokenSrc.get("metadata") == null) { return Optional.of(invalidGrantException("token is missing metadata")); } else { - return checkClient(refreshTokenSrc, clientAuth); + return checkLenientlyIfTokenAlreadyRefreshed(source, clientAuth); } } } @@ -830,21 +1020,86 @@ public final class TokenService { if (clientInfo == null) { return Optional.of(invalidGrantException("token is missing client information")); } else if (clientAuth.getUser().principal().equals(clientInfo.get("user")) == false) { + logger.warn("Token was originally created by [{}] but [{}] attempted to refresh it", clientInfo.get("user"), + clientAuth.getUser().principal()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else if (clientAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) { + logger.warn("[{}] created the refresh token while authenticated by [{}] but is now authenticated by [{}]", + clientInfo.get("user"), clientInfo.get("realm"), clientAuth.getAuthenticatedBy().getName()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else { return Optional.empty(); } } + /** + * Checks if the retrieved refresh token is already refreshed taking into consideration that we allow refresh tokens + * to be refreshed multiple times for a very small time window in order to gracefully handle multiple concurrent requests + * from clients + */ + @SuppressWarnings("unchecked") + private Optional checkLenientlyIfTokenAlreadyRefreshed(Map source, + Authentication userAuth) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final Integer version = (Integer) userTokenSource.get("version"); + Version authVersion = Version.fromId(version); + final Boolean refreshed = (Boolean) refreshTokenSrc.get("refreshed"); + if (refreshed) { + if (authVersion.onOrAfter(Version.V_7_1_0)) { + final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); + final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); + final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); + if (supersededBy == null) { + return Optional.of(invalidGrantException("token document is missing superseded by value")); + } else if (refreshTime == null) { + return Optional.of(invalidGrantException("token document is missing refresh time value")); + } else if (clock.instant().isAfter(refreshTime.plus(30L, ChronoUnit.SECONDS))) { + return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); + } + } else { + return Optional.of(invalidGrantException("token has already been refreshed")); + } + } + return checkClient(refreshTokenSrc, userAuth); + } + + /** + * Checks if a refreshed token is eligible to be refreshed again. 
This is only allowed for versions after 7.1.0 and + * when the refresh_token contains the refresh_time and superseded_by fields and it has been refreshed in a specific + * time period of 60 seconds. The period is defined as 30 seconds before the token was refreshed until 30 seconds after. The + * time window needs to handle instants before the request time as we capture an instant early on in + * {@link TokenService#refreshToken(String, ActionListener)} and in the case of multiple concurrent requests, + * the {@code refreshRequested} when dealing with one of the subsequent requests might well be before the instant when + * the first of the requests refreshed the token. + * + * @param source The source of the token document that contains the originally refreshed token + * @param refreshRequested The instant when the this refresh request was acknowledged by the TokenService + */ + private boolean eligibleForMultiRefresh(Map source, Instant refreshRequested) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final Integer version = (Integer) userTokenSource.get("version"); + Version authVersion = Version.fromId(version); + final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); + final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); + final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); + return authVersion.onOrAfter(Version.V_7_1_0) + && supersededBy != null + && refreshTime != null + && refreshRequested.isBefore(refreshTime.plus(30L, ChronoUnit.SECONDS)) + && refreshRequested.isAfter(refreshTime.minus(30L, ChronoUnit.SECONDS)); + } + /** * Find stored refresh and access tokens that have not been invalidated or expired, and were issued against - * the specified realm. + * the specified realm. 
* * @param realmName The name of the realm for which to get the tokens - * @param listener The listener to notify upon completion - * @param filter an optional Predicate to test the source of the found documents against + * @param listener The listener to notify upon completion + * @param filter an optional Predicate to test the source of the found documents against */ public void findActiveTokensForRealm(String realmName, ActionListener>> listener, @Nullable Predicate> filter) { @@ -893,7 +1148,6 @@ public final class TokenService { */ public void findActiveTokensForUser(String username, ActionListener>> listener) { ensureEnabled(); - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); if (Strings.isNullOrEmpty(username)) { listener.onFailure(new IllegalArgumentException("username is required")); @@ -958,17 +1212,15 @@ public final class TokenService { } /** - * * Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token * * @param source The token document source as retrieved * @param filter an optional Predicate to test the source of the UserToken against * @return A {@link Tuple} of access-token and refresh-token-id or null if a Predicate is defined and the userToken source doesn't - * satisfy it + * satisfy it */ private Tuple parseTokensFromDocument(Map source, @Nullable Predicate> filter) throws IOException { - final String refreshToken = (String) ((Map) source.get("refresh_token")).get("token"); final Map userTokenSource = (Map) ((Map) source.get("access_token")).get("user_token"); @@ -986,7 +1238,7 @@ public final class TokenService { in.setVersion(authVersion); Authentication authentication = new Authentication(in); return new Tuple<>(new UserToken(id, Version.fromId(version), authentication, Instant.ofEpochMilli(expiration), metadata), - refreshToken); + refreshToken); } } @@ -1063,7 +1315,6 @@ public final class TokenService { } } - public TimeValue getExpirationDelay() { return expirationDelay; } @@ -1088,35 +1339,48 @@ public final class TokenService { private String getFromHeader(ThreadContext threadContext) { String header = threadContext.getHeader("Authorization"); if (Strings.hasText(header) && header.regionMatches(true, 0, "Bearer ", 0, "Bearer ".length()) - && header.length() > "Bearer ".length()) { + && header.length() > "Bearer ".length()) { return header.substring("Bearer ".length()); } return null; } /** - * Serializes a token to a String containing an encrypted representation of the token + * Serializes a token to a String containing the version of the node that created the token and + * either an encrypted representation of the token id for versions earlier than 7.1.0 or the token id + * itself for versions on or after 7.1.0 */ - public String getUserTokenString(UserToken userToken) throws IOException, GeneralSecurityException { - // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly - try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); - KeyAndCache keyAndCache = keyCache.activeKeyCache; - Version.writeVersion(userToken.getVersion(), out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = getNewInitializationVector(); -
out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = - new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(userToken.getVersion()); - encryptedStreamOutput.writeString(userToken.getId()); - encryptedStreamOutput.close(); + public String getAccessTokenAsString(UserToken userToken) throws IOException, GeneralSecurityException { + if (clusterService.state().nodes().getMinNodeVersion().onOrAfter(Version.V_7_1_0)) { + try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); + Version.writeVersion(userToken.getVersion(), out); + out.writeString(userToken.getId()); return new String(os.toByteArray(), StandardCharsets.UTF_8); } + } else { + // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly + try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); + KeyAndCache keyAndCache = keyCache.activeKeyCache; + Version.writeVersion(userToken.getVersion(), out); + out.writeByteArray(keyAndCache.getSalt().bytes); + out.writeByteArray(keyAndCache.getKeyHash().bytes); + final byte[] initializationVector = getNewInitializationVector(); + out.writeByteArray(initializationVector); + try (CipherOutputStream encryptedOutput = + new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); + StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(userToken.getVersion()); + encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.close(); + return new String(os.toByteArray(), StandardCharsets.UTF_8); + } + } } } @@ -1125,7 +1389,8 @@ public final class TokenService { SecretKeyFactory.getInstance(KDF_ALGORITHM); } - private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { + // Package private for testing + Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); BytesKey salt = keyAndCache.getSalt(); try { @@ -1147,7 +1412,8 @@ public final class TokenService { return cipher; } - private byte[] getNewInitializationVector() { + // Package private for testing + byte[] getNewInitializationVector() { final byte[] initializationVector = new byte[IV_BYTES]; secureRandom.nextBytes(initializationVector); return initializationVector; @@ -1158,7 +1424,7 @@ public final class TokenService { * This method is computationally expensive. 
*/ static SecretKey computeSecretKey(char[] rawPassword, byte[] salt) - throws NoSuchAlgorithmException, InvalidKeySpecException { + throws NoSuchAlgorithmException, InvalidKeySpecException { SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance(KDF_ALGORITHM); PBEKeySpec keySpec = new PBEKeySpec(rawPassword, salt, ITERATIONS, 128); SecretKey tmp = secretKeyFactory.generateSecret(keySpec); @@ -1172,7 +1438,7 @@ public final class TokenService { */ private static ElasticsearchSecurityException expiredTokenException() { ElasticsearchSecurityException e = - new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); + new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); e.addHeader("WWW-Authenticate", EXPIRED_TOKEN_WWW_AUTH_VALUE); return e; } @@ -1269,8 +1535,8 @@ public final class TokenService { listener.onResponse(computedKey); } catch (ExecutionException e) { if (e.getCause() != null && - (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException - || e.getCause() instanceof IllegalArgumentException)) { + (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException + || e.getCause() instanceof IllegalArgumentException)) { // this could happen if another realm supports the Bearer token so we should // see if another realm can use this token! logger.debug("unable to decode bearer token", e); @@ -1305,7 +1571,7 @@ public final class TokenService { continue; // collision -- generate a new key } return newTokenMetaData(keyCache.currentTokenKeyHash, Iterables.concat(keyCache.cache.values(), - Collections.singletonList(keyAndCache))); + Collections.singletonList(keyAndCache))); } } return newTokenMetaData(keyCache.currentTokenKeyHash, keyCache.cache.values()); @@ -1335,10 +1601,10 @@ public final class TokenService { KeyAndCache currentKey = keyCache.get(keyCache.currentTokenKeyHash); ArrayList entries = new ArrayList<>(keyCache.cache.values()); Collections.sort(entries, - (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); + (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); for (KeyAndCache value : entries) { if (map.size() < numKeysToKeep || value.keyAndTimestamp.getTimestamp() >= currentKey - .keyAndTimestamp.getTimestamp()) { + .keyAndTimestamp.getTimestamp()) { logger.debug("keeping key {} ", value.getKeyHash()); map.put(value.getKeyHash(), value); } else { @@ -1417,16 +1683,16 @@ public final class TokenService { logger.info("rotate keys on master"); TokenMetaData tokenMetaData = generateSpareKey(); clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction( - ActionListener.wrap((res) -> { - if (res.isAcknowledged()) { - TokenMetaData metaData = rotateToSpareKey(); - clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction(listener, metaData)); - } else { - listener.onFailure(new IllegalStateException("not acked")); - } - }, listener::onFailure), tokenMetaData)); + new TokenMetadataPublishAction( + ActionListener.wrap((res) -> { + if (res.isAcknowledged()) { + TokenMetaData metaData = rotateToSpareKey(); + clusterService.submitStateUpdateTask("publish next key to prepare key rotation", + new TokenMetadataPublishAction(listener, metaData)); + } else { + listener.onFailure(new IllegalStateException("not acked")); + } + }, listener::onFailure), tokenMetaData)); 
} private final class TokenMetadataPublishAction extends AckedClusterStateUpdateTask { @@ -1528,12 +1794,19 @@ public final class TokenService { } /** - * For testing + * Package private for testing */ void clearActiveKeyCache() { this.keyCache.activeKeyCache.keyCache.invalidateAll(); } + /** + * Package private for testing + */ + KeyAndCache getActiveKeyCache() { + return this.keyCache.activeKeyCache; + } + static final class KeyAndCache implements Closeable { private final KeyAndTimestamp keyAndTimestamp; private final Cache keyCache; @@ -1543,9 +1816,9 @@ public final class TokenService { private KeyAndCache(KeyAndTimestamp keyAndTimestamp, BytesKey salt) { this.keyAndTimestamp = keyAndTimestamp; keyCache = CacheBuilder.builder() - .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) - .setMaximumWeight(500L) - .build(); + .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) + .setMaximumWeight(500L) + .build(); try { SecretKey secretKey = computeSecretKey(keyAndTimestamp.getKey().getChars(), salt.bytes); keyCache.put(salt, secretKey); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 085df140f3e..795cc9fb225 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -242,7 +242,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true); final UserToken userToken = future.actionGet().v1(); mockGetTokenFromId(userToken, false, client); - final String tokenString = tokenService.getUserTokenString(userToken); + final String tokenString = tokenService.getAccessTokenAsString(userToken); final SamlLogoutRequest request = new SamlLogoutRequest(); request.setToken(tokenString); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index 5eee33711e2..cda0586886c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -1109,7 +1109,7 @@ public class AuthenticationServiceTests extends ESTestCase { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); } - String token = tokenService.getUserTokenString(tokenFuture.get().v1()); + String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); mockGetTokenFromId(tokenFuture.get().v1(), false, client); when(securityIndex.isAvailable()).thenReturn(true); @@ -1192,7 +1192,7 @@ public class AuthenticationServiceTests extends ESTestCase { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, 
Collections.emptyMap(), true); } - String token = tokenService.getUserTokenString(tokenFuture.get().v1()); + String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); mockGetTokenFromId(tokenFuture.get().v1(), true, client); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[1]).run(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index 61ea4ef9672..7499d8be7d1 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -5,11 +5,14 @@ */ package org.elasticsearch.xpack.security.authc; +import org.apache.directory.api.util.Strings; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.common.settings.SecureString; @@ -23,6 +26,7 @@ import org.elasticsearch.test.SecuritySettingsSource; import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; @@ -38,7 +42,13 @@ import org.junit.Before; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -330,7 +340,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); } - public void testRefreshingMultipleTimes() { + public void testRefreshingMultipleTimesFails() throws Exception { Client client = client().filterWithHeader(Collections.singletonMap("Authorization", UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); @@ -343,12 +353,101 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { assertNotNull(createTokenResponse.getRefreshToken()); CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); assertNotNull(refreshResponse); + // We now have two documents, the original(now refreshed) token doc and the new one with the new access doc + AtomicReference docId = new AtomicReference<>(); + assertBusy(() -> { + SearchResponse searchResponse = 
client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setSource(SearchSourceBuilder.searchSource() + .query(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("doc_type", "token")) + .must(QueryBuilders.termQuery("refresh_token.refreshed", "true")))) + .setSize(1) + .setTerminateAfter(1) + .get(); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(1L)); + docId.set(searchResponse.getHits().getAt(0).getId()); + }); + // hack doc to modify the refresh time to 50 seconds ago so that we don't hit the lenient refresh case + Instant refreshed = Instant.now(); + Instant aWhileAgo = refreshed.minus(50L, ChronoUnit.SECONDS); + assertTrue(Instant.now().isAfter(aWhileAgo)); + UpdateResponse updateResponse = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get()) + .setDoc("refresh_token", Collections.singletonMap("refresh_time", aWhileAgo.toEpochMilli())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setFetchSource("refresh_token", Strings.EMPTY_STRING) + .get(); + assertNotNull(updateResponse); + Map refreshTokenMap = (Map) updateResponse.getGetResult().sourceAsMap().get("refresh_token"); + assertTrue( + Instant.ofEpochMilli((long) refreshTokenMap.get("refresh_time")).isBefore(Instant.now().minus(30L, ChronoUnit.SECONDS))); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); assertEquals("invalid_grant", e.getMessage()); assertEquals(RestStatus.BAD_REQUEST, e.status()); - assertEquals("token has already been refreshed", e.getHeader("error_description").get(0)); + assertEquals("token has already been refreshed more than 30 seconds in the past", e.getHeader("error_description").get(0)); + } + + public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + Set refreshTokens = new HashSet<>(); + Set accessTokens = new HashSet<>(); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); + final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); + List threads = new ArrayList<>(numberOfThreads); + final CountDownLatch readyLatch = new CountDownLatch(numberOfThreads + 1); + final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads); + AtomicBoolean failed = new AtomicBoolean(); + for (int i = 0; i < numberOfThreads; i++) { + threads.add(new Thread(() -> { + // Each thread gets its own client so that more than one node will be hit + Client threadClient = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient threadSecurityClient = new SecurityClient(threadClient); + CreateTokenRequest refreshRequest = +
threadSecurityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).request(); + readyLatch.countDown(); + try { + readyLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + completedLatch.countDown(); + return; + } + threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> { + accessTokens.add(result.getTokenString()); + refreshTokens.add(result.getRefreshToken()); + logger.info("received access token [{}] and refresh token [{}]", result.getTokenString(), result.getRefreshToken()); + completedLatch.countDown(); + }, e -> { + failed.set(true); + completedLatch.countDown(); + logger.error("caught exception", e); + })); + })); + } + for (Thread thread : threads) { + thread.start(); + } + readyLatch.countDown(); + readyLatch.await(); + for (Thread thread : threads) { + thread.join(); + } + completedLatch.await(); + assertThat(failed.get(), equalTo(false)); + assertThat(accessTokens.size(), equalTo(1)); + assertThat(refreshTokens.size(), equalTo(1)); } public void testRefreshAsDifferentUser() { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 8caf82e8648..7efb4b51632 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.security.authc; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.get.GetAction; @@ -23,6 +24,8 @@ import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -51,7 +54,11 @@ import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -61,6 +68,7 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; +import javax.crypto.CipherOutputStream; import javax.crypto.SecretKey; import static java.time.Clock.systemUTC; @@ -151,7 +159,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -198,7 +206,7 @@ public 
class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -219,10 +227,10 @@ public class TokenServiceTests extends ESTestCase { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); + assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { @@ -247,7 +255,7 @@ public class TokenServiceTests extends ESTestCase { rotateKeys(tokenService); } TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, - clusterService); + clusterService); otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); PlainActionFuture> tokenFuture = new PlainActionFuture<>(); @@ -258,7 +266,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); otherTokenService.getAndValidateToken(requestContext, future); @@ -289,7 +297,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -316,7 +324,7 @@ public class TokenServiceTests extends ESTestCase { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); + assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); metaData = tokenService.pruneKeys(1); 
tokenService.refreshMetaData(metaData); @@ -329,7 +337,7 @@ public class TokenServiceTests extends ESTestCase { } requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -351,7 +359,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -362,8 +370,8 @@ public class TokenServiceTests extends ESTestCase { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // verify a second separate token service with its own passphrase cannot verify - TokenService anotherService = new TokenService(Settings.EMPTY, systemUTC(), client, securityIndex, - clusterService); + TokenService anotherService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, + clusterService); PlainActionFuture future = new PlainActionFuture<>(); anotherService.getAndValidateToken(requestContext, future); assertNull(future.get()); @@ -377,10 +385,10 @@ public class TokenServiceTests extends ESTestCase { PlainActionFuture> tokenFuture = new PlainActionFuture<>(); tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true); UserToken token = tokenFuture.get().v1(); - assertThat(tokenService.getUserTokenString(token), notNullValue()); + assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); tokenService.clearActiveKeyCache(); - assertThat(tokenService.getUserTokenString(token), notNullValue()); + assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); } public void testInvalidatedToken() throws Exception { @@ -395,7 +403,7 @@ public class TokenServiceTests extends ESTestCase { mockGetTokenFromId(token, true); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -449,7 +457,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // the clock is still frozen, so the cookie should be valid @@ -559,7 +567,7 @@ public class TokenServiceTests extends ESTestCase { //mockGetTokenFromId(token, false); ThreadContext requestContext = 
new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); doAnswer(invocationOnMock -> { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -598,7 +606,7 @@ public class TokenServiceTests extends ESTestCase { Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); mockGetTokenFromId(expired, false); - String userTokenString = tokenService.getUserTokenString(expired); + String userTokenString = tokenService.getAccessTokenAsString(expired); PlainActionFuture>> authFuture = new PlainActionFuture<>(); tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); Authentication retrievedAuth = authFuture.actionGet().v1(); @@ -639,4 +647,28 @@ public class TokenServiceTests extends ESTestCase { assertEquals(expected.getMetadata(), result.getMetadata()); assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); } + + protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException, + GeneralSecurityException { + try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(Version.V_7_0_0); + TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache(); + Version.writeVersion(Version.V_7_0_0, out); + out.writeByteArray(keyAndCache.getSalt().bytes); + out.writeByteArray(keyAndCache.getKeyHash().bytes); + final byte[] initializationVector = tokenService.getNewInitializationVector(); + out.writeByteArray(initializationVector); + try (CipherOutputStream encryptedOutput = + new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0)); + StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(Version.V_7_0_0); + encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.close(); + return new String(os.toByteArray(), StandardCharsets.UTF_8); + } + } + } + } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java index f180e356b76..55ae297ae4e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java @@ -25,8 +25,7 @@ public class TokensInvalidationResultTests extends ESTestCase { TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Arrays.asList("token3", "token4"), Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), - new ElasticsearchException("boo", new IllegalStateException("far"))), - randomIntBetween(0, 5)); + new ElasticsearchException("boo", new IllegalStateException("far")))); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, 
ToXContent.EMPTY_PARAMS); @@ -56,9 +55,8 @@ public class TokensInvalidationResultTests extends ESTestCase { } public void testToXcontentWithNoErrors() throws Exception{ - TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), - Collections.emptyList(), - Collections.emptyList(), randomIntBetween(0, 5)); + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Collections.emptyList(), + Collections.emptyList()); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, ToXContent.EMPTY_PARAMS); assertThat(Strings.toString(builder),