diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java index f6e7965d963..f9985dfba7a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.security.authc.support; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; import org.elasticsearch.common.Nullable; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -32,10 +33,9 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { private final List invalidatedTokens; private final List previouslyInvalidatedTokens; private final List errors; - private final int attemptCount; public TokensInvalidationResult(List invalidatedTokens, List previouslyInvalidatedTokens, - @Nullable List errors, int attemptCount) { + @Nullable List errors) { Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); this.invalidatedTokens = invalidatedTokens; Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); @@ -45,18 +45,19 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { } else { this.errors = Collections.emptyList(); } - this.attemptCount = attemptCount; } public TokensInvalidationResult(StreamInput in) throws IOException { this.invalidatedTokens = in.readStringList(); this.previouslyInvalidatedTokens = in.readStringList(); this.errors = in.readList(StreamInput::readException); - this.attemptCount = in.readVInt(); + if (in.getVersion().before(Version.V_7_1_0)) { + in.readVInt(); + } } public static TokensInvalidationResult emptyResult() { - return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); + return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList()); } @@ -72,10 +73,6 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { return errors; } - public int getAttemptCount() { - return attemptCount; - } - @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject() @@ -100,6 +97,8 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable { out.writeStringCollection(invalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens); out.writeCollection(errors, StreamOutput::writeException); - out.writeVInt(attemptCount); + if (out.getVersion().before(Version.V_7_1_0)) { + out.writeVInt(5); + } } } diff --git a/x-pack/plugin/core/src/main/resources/security-index-template.json b/x-pack/plugin/core/src/main/resources/security-index-template.json index 183ffff4ea5..e938464ac6f 100644 --- a/x-pack/plugin/core/src/main/resources/security-index-template.json +++ b/x-pack/plugin/core/src/main/resources/security-index-template.json @@ -199,6 +199,13 @@ "refreshed" : { "type" : "boolean" }, + "refresh_time": { + "type": "date", + "format": "epoch_millis" + }, + "superseded_by": { + "type": "keyword" + }, "invalidated" : { "type" : "boolean" }, diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java index c9c2e470644..bbfba920e38 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java @@ -29,8 +29,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), - randomIntBetween(0, 5)); + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); InvalidateTokenResponse response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -47,8 +46,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { } result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), - Arrays.asList(generateRandomStringArray(20, 15, false)), - Collections.emptyList(), randomIntBetween(0, 5)); + Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList()); response = new InvalidateTokenResponse(result); try (BytesStreamOutput output = new BytesStreamOutput()) { response.writeTo(output); @@ -68,8 +66,7 @@ public class InvalidateTokenResponseTests extends ESTestCase { List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), - new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), - randomIntBetween(0, 5)); + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2")))); InvalidateTokenResponse response = new InvalidateTokenResponse(result); XContentBuilder builder = XContentFactory.jsonBuilder(); response.toXContent(builder, ToXContent.EMPTY_PARAMS); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java index dee12f4a6bd..0e5acf5394f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlAuthenticateAction.java @@ -63,7 +63,7 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio final Map tokenMeta = (Map) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); tokenService.createUserToken(authentication, originatingAuthentication, ActionListener.wrap(tuple -> { - final String tokenString = tokenService.getUserTokenString(tuple.v1()); + final String tokenString = tokenService.getAccessTokenAsString(tuple.v1()); final TimeValue expiresIn = tokenService.getExpirationDelay(); listener.onResponse( new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java index 5d5442803e3..75c3ee9df42 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java @@ -89,7 +89,7 @@ public final class TransportCreateTokenAction extends HandledTransportAction listener) { try { tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java index 0eac8d71fb2..71aeb64bc42 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java @@ -31,7 +31,7 @@ public class TransportRefreshTokenAction extends HandledTransportAction listener) { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { - final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1()); final String scope = getResponseScopeValue(request.getScope()); final CreateTokenResponse response = diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index 36144899d28..b1db11b48f0 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -19,6 +19,7 @@ import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest.OpType; import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BackoffPolicy; import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkResponse; @@ -28,6 +29,7 @@ import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.support.TransportActions; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.update.UpdateRequest; @@ -67,6 +69,7 @@ import org.elasticsearch.core.internal.io.IOUtils; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.seqno.SequenceNumbers; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchHit; import org.elasticsearch.xpack.core.XPackField; @@ -113,12 +116,12 @@ import java.util.Collection; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; import java.util.function.Predicate; @@ -129,6 +132,7 @@ import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; +import static org.elasticsearch.threadpool.ThreadPool.Names.GENERIC; /** * Service responsible for the creation, validation, and other management of {@link UserToken} @@ -155,6 +159,7 @@ public final class TokenService { private static final String MALFORMED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY + "\", error=\"invalid_token\", error_description=\"The access token is malformed\""; private static final String TYPE = "doc"; + private static final BackoffPolicy DEFAULT_BACKOFF = BackoffPolicy.exponentialBackoff(); public static final String THREAD_POOL_NAME = XPackField.SECURITY + "-token-key"; public static final Setting TOKEN_EXPIRATION = Setting.timeSetting("xpack.security.authc.token.timeout", @@ -167,8 +172,7 @@ public final class TokenService { private static final String TOKEN_DOC_TYPE = "token"; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; - private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); - private static final int MAX_RETRY_ATTEMPTS = 5; + static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); private static final Logger logger = LogManager.getLogger(TokenService.class); private final SecureRandom secureRandom = new SecureRandom(); @@ -221,12 +225,22 @@ public final class TokenService { } /** - * Create a token based on the provided authentication and metadata. + * Creates a token based on the provided authentication and metadata with an auto-generated token id. * The created token will be stored in the security index. */ public void createUserToken(Authentication authentication, Authentication originatingClientAuth, ActionListener> listener, Map metadata, boolean includeRefreshToken) throws IOException { + createUserToken(UUIDs.randomBase64UUID(), authentication, originatingClientAuth, listener, metadata, includeRefreshToken); + } + + /** + * Create a token based on the provided authentication and metadata with the given token id. + * The created token will be stored in the security index. + */ + private void createUserToken(String userTokenId, Authentication authentication, Authentication originatingClientAuth, + ActionListener> listener, Map metadata, + boolean includeRefreshToken) throws IOException { ensureEnabled(); if (authentication == null) { listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); @@ -239,7 +253,7 @@ public final class TokenService { final Version version = clusterService.state().nodes().getMinNodeVersion(); final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), authentication.getLookedUpBy(), version, AuthenticationType.TOKEN, authentication.getMetadata()); - final UserToken userToken = new UserToken(version, tokenAuth, expiration, metadata); + final UserToken userToken = new UserToken(userTokenId, version, tokenAuth, expiration, metadata); final String refreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null; try (XContentBuilder builder = XContentFactory.jsonBuilder()) { @@ -280,9 +294,33 @@ public final class TokenService { } } + /** + * Reconstructs the {@link UserToken} from the existing {@code userTokenSource} and call the listener with the {@link UserToken} and the + * refresh token string + */ + private void reIssueTokens(Map userTokenSource, + String refreshToken, ActionListener> listener) { + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + final String id = (String) userTokenSource.get("id"); + final Long expiration = (Long) userTokenSource.get("expiration_time"); + + Version authVersion = Version.fromId(version); + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + Authentication authentication = new Authentication(in); + UserToken userToken = new UserToken(id, authVersion, authentication, Instant.ofEpochMilli(expiration), metadata); + listener.onResponse(new Tuple<>(userToken, refreshToken)); + } catch (IOException e) { + logger.error("Unable to decode existing user token", e); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } + } + /** * Looks in the context to see if the request provided a header with a user token and if so the - * token is validated, which includes authenticated decryption and verification that the token + * token is validated, which might include authenticated decryption and verification that the token * has not been revoked or is expired. */ void getAndValidateToken(ThreadContext ctx, ActionListener listener) { @@ -329,23 +367,78 @@ public final class TokenService { } /** - * Asynchronously decodes the string representation of a {@link UserToken}. The process for - * this is asynchronous as we may need to compute a key, which can be computationally expensive - * so this should not block the current thread, which is typically a network thread. A second - * reason for being asynchronous is that we can restrain the amount of resources consumed by - * the key computation to a single thread. + * Gets the UserToken with given id by fetching the the corresponding token document + */ + void getUserTokenFromId(String userTokenId, ActionListener listener) { + if (securityIndex.isAvailable() == false) { + logger.warn("failed to get token [{}] since index is not available", userTokenId); + listener.onResponse(null); + } else { + securityIndex.checkIndexVersionThenExecute( + ex -> listener.onFailure(traceLog("prepare security index", userTokenId, ex)), + () -> { + final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, + getTokenDocumentId(userTokenId)).request(); + Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", userTokenId, ex)); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, + ActionListener.wrap(response -> { + if (response.isExists()) { + Map accessTokenSource = + (Map) response.getSource().get("access_token"); + if (accessTokenSource == null) { + onFailure.accept(new IllegalStateException( + "token document is missing the access_token field")); + } else if (accessTokenSource.containsKey("user_token") == false) { + onFailure.accept(new IllegalStateException( + "token document is missing the user_token field")); + } else { + Map userTokenSource = + (Map) accessTokenSource.get("user_token"); + listener.onResponse(UserToken.fromSourceMap(userTokenSource)); + } + } else { + onFailure.accept( + new IllegalStateException("token document is missing and must be present")); + } + }, e -> { + // if the index or the shard is not there / available we assume that + // the token is not valid + if (isShardNotAvailableException(e)) { + logger.warn("failed to get token [{}] since index is not available", userTokenId); + listener.onResponse(null); + } else { + logger.error(new ParameterizedMessage("failed to get token [{}]", userTokenId), e); + listener.onFailure(e); + } + }), client::get); + }); + } + } + + /* + * If needed, for tokens that were created in a pre 7.1.0 cluster, it asynchronously decodes the token to get the token document Id. + * The process for this is asynchronous as we may need to compute a key, which can be computationally expensive + * so this should not block the current thread, which is typically a network thread. A second reason for being asynchronous is that + * we can restrain the amount of resources consumed by the key computation to a single thread. + * For tokens created in an after 7.1.0 cluster, the token is just the token document Id so this is used directly without decryption */ void decodeToken(String token, ActionListener listener) throws IOException { // We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key! byte[] bytes = token.getBytes(StandardCharsets.UTF_8); StreamInput in = new InputStreamStreamInput(Base64.getDecoder().wrap(new ByteArrayInputStream(bytes)), bytes.length); - if (in.available() < MINIMUM_BASE64_BYTES) { - logger.debug("invalid token"); - listener.onResponse(null); + final Version version = Version.readVersion(in); + if (version.onOrAfter(Version.V_7_1_0)) { + // The token was created in a > 7.1.0 cluster so it contains the tokenId as a String + String usedTokenId = in.readString(); + getUserTokenFromId(usedTokenId, listener); } else { - // the token exists and the value is at least as long as we'd expect - final Version version = Version.readVersion(in); + // The token was created in a < 7.1.0 cluster so we need to decrypt it to get the tokenId in.setVersion(version); + if (in.available() < MINIMUM_BASE64_BYTES) { + logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BASE64_BYTES); + listener.onResponse(null); + return; + } final BytesKey decodedSalt = new BytesKey(in.readByteArray()); final BytesKey passphraseHash = new BytesKey(in.readByteArray()); KeyAndCache keyAndCache = keyCache.get(passphraseHash); @@ -354,51 +447,8 @@ public final class TokenService { try { final byte[] iv = in.readByteArray(); final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt); - decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> { - if (securityIndex.isAvailable() == false) { - logger.warn("failed to get token [{}] since index is not available", tokenId); - listener.onResponse(null); - } else { - securityIndex.checkIndexVersionThenExecute( - ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)), - () -> { - final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, - getTokenDocumentId(tokenId)).request(); - Consumer onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex)); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists()) { - Map accessTokenSource = - (Map) response.getSource().get("access_token"); - if (accessTokenSource == null) { - onFailure.accept(new IllegalStateException( - "token document is missing the access_token field")); - } else if (accessTokenSource.containsKey("user_token") == false) { - onFailure.accept(new IllegalStateException( - "token document is missing the user_token field")); - } else { - Map userTokenSource = - (Map) accessTokenSource.get("user_token"); - listener.onResponse(UserToken.fromSourceMap(userTokenSource)); - } - } else { - onFailure.accept( - new IllegalStateException("token document is missing and must be present")); - } - }, e -> { - // if the index or the shard is not there / available we assume that - // the token is not valid - if (isShardNotAvailableException(e)) { - logger.warn("failed to get token [{}] since index is not available", tokenId); - listener.onResponse(null); - } else { - logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e); - listener.onFailure(e); - } - }), client::get); - }); - } - }, listener::onFailure)); + decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> getUserTokenFromId(tokenId, listener), + listener::onFailure)); } catch (GeneralSecurityException e) { // could happen with a token that is not ours logger.warn("invalid token", e); @@ -442,8 +492,8 @@ public final class TokenService { /** * This method performs the steps necessary to invalidate a token so that it may no longer be - * used. The process of invalidation involves performing an update to - * the token document and setting the invalidated field to true + * used. The process of invalidation involves performing an update to the token document and setting + * the invalidated field to true */ public void invalidateAccessToken(String tokenString, ActionListener listener) { ensureEnabled(); @@ -452,12 +502,13 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("token must be provided")); } else { maybeStartTokenRemover(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); try { decodeToken(tokenString, ActionListener.wrap(userToken -> { if (userToken == null) { listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException())); } else { - indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), + indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null); } }, listener::onFailure)); @@ -480,12 +531,14 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("token must be provided")); } else { maybeStartTokenRemover(); - indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); + indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null); } } /** - * This method performs the steps necessary to invalidate a refresh token so that it may no longer be used. + * This method onvalidates a refresh token so that it may no longer be used. Iinvalidation involves performing an update to the token + * document and setting the refresh_token.invalidated field to true * * @param refreshToken The string representation of the refresh token * @param listener the listener to notify upon completion @@ -497,16 +550,17 @@ public final class TokenService { listener.onFailure(new IllegalArgumentException("refresh token must be provided")); } else { maybeStartTokenRemover(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(tuple -> { - final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId()); - indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null); - }, listener::onFailure), new AtomicInteger(0)); + ActionListener.wrap(searchResponse -> { + final String docId = getTokenIdFromDocumentId(searchResponse.getHits().getAt(0).getId()); + indexInvalidation(Collections.singletonList(docId), listener, backoff, "refresh_token", null); + }, listener::onFailure), backoff); } } /** - * Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given + * Invalidates all access tokens and all refresh tokens of a given {@code realmName} and/or of a given * {@code username} so that they may no longer be used * * @param realmName the realm of which the tokens should be invalidated @@ -557,32 +611,30 @@ public final class TokenService { maybeStartTokenRemover(); // Invalidate the refresh tokens first so that they cannot be used to get new // access tokens while we invalidate the access tokens we currently know about + final Iterator backoff = DEFAULT_BACKOFF.iterator(); indexInvalidation(accessTokenIds, ActionListener.wrap(result -> - indexInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()), - "access_token", result), - listener::onFailure), new AtomicInteger(0), "refresh_token", null); + indexInvalidation(accessTokenIds, listener, backoff, "access_token", result), + listener::onFailure), backoff, "refresh_token", null); } /** - * Performs the actual invalidation of a collection of tokens + * Performs the actual invalidation of a collection of tokens. In case of recoverable errors ( see + * {@link TransportActions#isShardNotAvailableException} ) the UpdateRequests to mark the tokens as invalidated are retried using + * an exponential backoff policy. * * @param tokenIds the tokens to invalidate * @param listener the listener to notify upon completion - * @param attemptCount the number of attempts to invalidate that have already been tried + * @param backoff the amount of time to delay between attempts * @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on * what type of tokens should be invalidated * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating * tokens up to the point of the retry. This result is added to the result of the current attempt */ private void indexInvalidation(Collection tokenIds, ActionListener listener, - AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { + Iterator backoff, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { if (tokenIds.isEmpty()) { logger.warn("No [{}] tokens provided for invalidation", srcPrefix); listener.onFailure(invalidGrantException("No tokens provided for invalidation")); - } else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(), - attemptCount.get()); - listener.onFailure(invalidGrantException("failed to invalidate tokens")); } else { BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); for (String tokenId : tokenIds) { @@ -627,20 +679,30 @@ public final class TokenService { } } if (retryTokenDocIds.isEmpty() == false) { - TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, - failedRequestResponses, attemptCount.get()); - attemptCount.incrementAndGet(); - indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult); + if (backoff.hasNext()) { + logger.debug("failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too", + retryTokenDocIds.size(), tokenIds.size()); + TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses); + client.threadPool().schedule( + () -> indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult), + backoff.next(), GENERIC); + } else { + logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries", + retryTokenDocIds.size(), tokenIds.size()); + } + } else { + TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses); + listener.onResponse(result); } - TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, - failedRequestResponses, attemptCount.get()); - listener.onResponse(result); }, e -> { Throwable cause = ExceptionsHelper.unwrapCause(e); traceLog("invalidate tokens", cause); - if (isShardNotAvailableException(cause)) { - attemptCount.incrementAndGet(); - indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, previousResult); + if (isShardNotAvailableException(cause) && backoff.hasNext()) { + logger.debug("failed to invalidate tokens, retrying "); + client.threadPool().schedule( + () -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult), backoff.next(), GENERIC); } else { listener.onFailure(e); } @@ -649,142 +711,272 @@ public final class TokenService { } /** - * Uses the refresh token to refresh its associated token and returns the new token with an - * updated expiration date to the listener + * Called by the transport action in order to start the process of refreshing a token. */ public void refreshToken(String refreshToken, ActionListener> listener) { ensureEnabled(); + final Instant refreshRequested = clock.instant(); + final Iterator backoff = DEFAULT_BACKOFF.iterator(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(tuple -> { + ActionListener.wrap(searchResponse -> { final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); - final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); - innerRefresh(tokenDocId, clientAuth, listener, tuple.v2()); + final SearchHit tokenDocHit = searchResponse.getHits().getHits()[0]; + final String tokenDocId = tokenDocHit.getId(); + innerRefresh(tokenDocId, tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), tokenDocHit.getPrimaryTerm(), clientAuth, + listener, backoff, refreshRequested); }, listener::onFailure), - new AtomicInteger(0)); + backoff); } - private void findTokenFromRefreshToken(String refreshToken, ActionListener> listener, - AtomicInteger attemptCount) { - if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to find token for refresh token [{}] after [{}] attempts", refreshToken, attemptCount.get()); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else { - SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setQuery(QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) - .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) - .setVersion(true) - .request(); + /** + * Performs an asynchronous search request for the token document that contains the {@code refreshToken} and calls the listener with the + * {@link SearchResponse}. In case of recoverable errors the SearchRequest is retried using an exponential backoff policy. + */ + private void findTokenFromRefreshToken(String refreshToken, ActionListener listener, + Iterator backoff) { + SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setQuery(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) + .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) + .seqNoAndPrimaryTerm(true) + .request(); - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); - if (frozenSecurityIndex.indexExists() == false) { - logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else if (frozenSecurityIndex.isAvailable() == false) { - logger.debug("security index is not available to find token from refresh token, retrying"); - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); - } else { - Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); - securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, - ActionListener.wrap(searchResponse -> { - if (searchResponse.isTimedOut()) { - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); - } else if (searchResponse.getHits().getHits().length < 1) { - logger.info("could not find token document with refresh_token [{}]", refreshToken); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } else if (searchResponse.getHits().getHits().length > 1) { - onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); + final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); + if (frozenSecurityIndex.indexExists() == false) { + logger.warn("security index does not exist therefore refresh token [{}] cannot be validated", refreshToken); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else if (frozenSecurityIndex.isAvailable() == false) { + logger.debug("security index is not available to find token from refresh token, retrying"); + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); + } else { + Consumer onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); + securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, + ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + if (backoff.hasNext()) { + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); } else { - listener.onResponse(new Tuple<>(searchResponse, attemptCount)); + logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else if (searchResponse.getHits().getHits().length < 1) { + logger.warn("could not find token document with refresh_token [{}]", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } else if (searchResponse.getHits().getHits().length > 1) { + onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); + } else { + listener.onResponse(searchResponse); + } + }, e -> { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.debug("failed to find token for refresh token [{}], retrying", refreshToken); + client.threadPool().scheduleWithFixedDelay( + () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC); + } else { + logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + onFailure.accept(e); + } + }), + client::search)); + } + } + + /** + * Performs the actual refresh of the token with retries in case of certain exceptions that may be recoverable. The + * refresh involves two steps: + * First, we check if the token document is still valid for refresh ({@link TokenService#checkTokenDocForRefresh(Map, Authentication)} + * Then, in the case that the token has been refreshed within the previous 30 seconds (see + * {@link TokenService#checkLenientlyIfTokenAlreadyRefreshed(Map, Authentication)}), we do not create a new token document + * but instead retrieve the one that was created by the original refresh and return a user token and + * refresh token based on that ( see {@link TokenService#reIssueTokens(Map, String, ActionListener)} ). + * Otherwise this token document gets its refresh_token marked as refreshed, while also storing the Instant when it was + * refreshed along with a pointer to the new token document that holds the refresh_token that supersedes this one. The new + * document that contains the new access token and refresh token is created and finally the new access token and refresh token are + * returned to the listener. + */ + private void innerRefresh(String tokenDocId, Map source, long seqNo, long primaryTerm, Authentication clientAuth, + ActionListener> listener, Iterator backoff, Instant refreshRequested) { + logger.debug("Attempting to refresh token [{}]", tokenDocId); + Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); + final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); + if (invalidSource.isPresent()) { + onFailure.accept(invalidSource.get()); + } else { + if (eligibleForMultiRefresh(source, refreshRequested)) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final String supersedingTokenDocId = (String) refreshTokenSrc.get("superseded_by"); + logger.debug("Token document [{}] was recently refreshed, attempting to reuse [{}] for returning an " + + "access token and refresh token", tokenDocId, supersedingTokenDocId); + final ActionListener getSupersedingListener = new ActionListener() { + @Override + public void onResponse(GetResponse response) { + if (response.isExists()) { + logger.debug("Found superseding token document [{}] ", supersedingTokenDocId); + final Map supersedingTokenSource = response.getSource(); + final Map supersedingUserTokenSource = (Map) + ((Map) supersedingTokenSource.get("access_token")).get("user_token"); + final Map supersedingRefreshTokenSrc = + (Map) supersedingTokenSource.get("refresh_token"); + final String supersedingRefreshTokenValue = (String) supersedingRefreshTokenSrc.get("token"); + reIssueTokens(supersedingUserTokenSource, supersedingRefreshTokenValue, listener); + } else if (backoff.hasNext()) { + // We retry this since the creation of the superseding token document might already be in flight but not + // yet completed, triggered by a refresh request that came a few milliseconds ago + logger.info("could not find superseding token document [{}] for token document [{}], retrying", + supersedingTokenDocId, tokenDocId); + client.threadPool().schedule(() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not find superseding token document [{}] for token document [{}] after all retries", + supersedingTokenDocId, tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + + @Override + public void onFailure(Exception e) { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId); + client.threadPool().schedule( + () -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not find token document [{}] for refresh after all retries", supersedingTokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + logger.warn("could not find superseding token document [{}] for refresh", supersedingTokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + }; + getTokenDocAsync(supersedingTokenDocId, getSupersedingListener); + } else { + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + Version authVersion = Version.fromId(version); + Authentication authentication; + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + authentication = new Authentication(in); + } catch (IOException e) { + logger.error("failed to decode the authentication stored with token document [{}]", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + return; + } + final String newUserTokenId = UUIDs.randomBase64UUID(); + final Instant refreshTime = clock.instant(); + Map updateMap = new HashMap<>(); + updateMap.put("refreshed", true); + updateMap.put("refresh_time", refreshTime.toEpochMilli()); + updateMap.put("superseded_by", getTokenDocumentId(newUserTokenId)); + UpdateRequestBuilder updateRequest = + client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) + .setDoc("refresh_token", updateMap) + .setFetchSource(true) + .setRefreshPolicy(RefreshPolicy.IMMEDIATE); + assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number"; + updateRequest.setIfSeqNo(seqNo); + assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term"; + updateRequest.setIfPrimaryTerm(primaryTerm); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), + ActionListener.wrap( + updateResponse -> { + if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + logger.debug("updated the original token document to {}", updateResponse.getGetResult().sourceAsMap()); + createUserToken(newUserTokenId, authentication, clientAuth, listener, metadata, true); + } else if (backoff.hasNext()) { + logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying", + tokenDocId, updateResponse.getResult()); + client.threadPool().schedule( + () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, + refreshRequested), + backoff.next(), GENERIC); + } else { + logger.info("failed to update the original token document [{}] after all retries, " + + "the update result was [{}]. ", tokenDocId, updateResponse.getResult()); + listener.onFailure(invalidGrantException("could not refresh the requested token")); } }, e -> { - if (isShardNotAvailableException(e)) { - logger.debug("failed to search for token document, retrying", e); - attemptCount.incrementAndGet(); - findTokenFromRefreshToken(refreshToken, listener, attemptCount); + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof VersionConflictEngineException) { + //The document has been updated by another thread, get it again. + if (backoff.hasNext()) { + logger.debug("version conflict while updating document [{}], attempting to get it again", + tokenDocId); + final ActionListener getListener = new ActionListener() { + @Override + public void onResponse(GetResponse response) { + if (response.isExists()) { + innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(), + response.getPrimaryTerm(), clientAuth, listener, backoff, refreshRequested); + } else { + logger.warn("could not find token document [{}] for refresh", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } + + @Override + public void onFailure(Exception e) { + if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.info("could not get token document [{}] for refresh, " + + "retrying", tokenDocId); + client.threadPool().schedule( + () -> getTokenDocAsync(tokenDocId, this), backoff.next(), GENERIC); + } else { + logger.warn("could not get token document [{}] for refresh after all retries", + tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else { + onFailure.accept(e); + } + } + }; + getTokenDocAsync(tokenDocId, getListener); + } else { + logger.warn("version conflict while updating document [{}], no retries left", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } + } else if (isShardNotAvailableException(e)) { + if (backoff.hasNext()) { + logger.debug("failed to update the original token document [{}], retrying", tokenDocId); + client.threadPool().schedule( + () -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff, + refreshRequested), + backoff.next(), GENERIC); + } else { + logger.warn("failed to update the original token document [{}], after all retries", tokenDocId); + onFailure.accept(invalidGrantException("could not refresh the requested token")); + } } else { onFailure.accept(e); } }), - client::search)); + client::update); } } } - /** - * Performs the actual refresh of the token with retries in case of certain exceptions that - * may be recoverable. The refresh involves retrieval of the token document and then - * updating the token document to indicate that the document has been refreshed. - */ - private void innerRefresh(String tokenDocId, Authentication clientAuth, ActionListener> listener, - AtomicInteger attemptCount) { - if (attemptCount.getAndIncrement() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to refresh token for doc [{}] after [{}] attempts", tokenDocId, attemptCount.get()); - listener.onFailure(invalidGrantException("could not refresh the requested token")); - } else { - Consumer onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); - GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, - ActionListener.wrap(response -> { - if (response.isExists()) { - final Map source = response.getSource(); - final Optional invalidSource = checkTokenDocForRefresh(source, clientAuth); - - if (invalidSource.isPresent()) { - onFailure.accept(invalidSource.get()); - } else { - final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); - final String authString = (String) userTokenSource.get("authentication"); - final Integer version = (Integer) userTokenSource.get("version"); - final Map metadata = (Map) userTokenSource.get("metadata"); - - Version authVersion = Version.fromId(version); - try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { - in.setVersion(authVersion); - Authentication authentication = new Authentication(in); - UpdateRequestBuilder updateRequest = - client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) - .setDoc("refresh_token", Collections.singletonMap("refreshed", true)) - .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); - updateRequest.setIfSeqNo(response.getSeqNo()); - updateRequest.setIfPrimaryTerm(response.getPrimaryTerm()); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), - ActionListener.wrap( - updateResponse -> createUserToken(authentication, clientAuth, listener, metadata, true), - e -> { - Throwable cause = ExceptionsHelper.unwrapCause(e); - if (cause instanceof VersionConflictEngineException || - isShardNotAvailableException(e)) { - innerRefresh(tokenDocId, clientAuth, - listener, attemptCount); - } else { - onFailure.accept(e); - } - }), - client::update); - } - } - } else { - logger.info("could not find token document [{}] for refresh", tokenDocId); - onFailure.accept(invalidGrantException("could not refresh the requested token")); - } - }, e -> { - if (isShardNotAvailableException(e)) { - innerRefresh(tokenDocId, clientAuth, listener, attemptCount); - } else { - listener.onFailure(e); - } - }), client::get); - } + private void getTokenDocAsync(String tokenDocId, ActionListener listener) { + GetRequest getRequest = + client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get); } /** * Performs checks on the retrieved source and returns an {@link Optional} with the exception - * if there is an issue + * if there is an issue that makes the retrieved token unsuitable to be refreshed */ private Optional checkTokenDocForRefresh(Map source, Authentication clientAuth) { final Map refreshTokenSrc = (Map) source.get("refresh_token"); @@ -805,8 +997,6 @@ public final class TokenService { return Optional.of(invalidGrantException("token document is missing invalidated value")); } else if (creationEpochMilli == null) { return Optional.of(invalidGrantException("token document is missing creation time value")); - } else if (refreshed) { - return Optional.of(invalidGrantException("token has already been refreshed")); } else if (invalidated) { return Optional.of(invalidGrantException("token has been invalidated")); } else if (clock.instant().isAfter(creationTime.plus(24L, ChronoUnit.HOURS))) { @@ -820,7 +1010,7 @@ public final class TokenService { } else if (userTokenSrc.get("metadata") == null) { return Optional.of(invalidGrantException("token is missing metadata")); } else { - return checkClient(refreshTokenSrc, clientAuth); + return checkLenientlyIfTokenAlreadyRefreshed(source, clientAuth); } } } @@ -830,21 +1020,86 @@ public final class TokenService { if (clientInfo == null) { return Optional.of(invalidGrantException("token is missing client information")); } else if (clientAuth.getUser().principal().equals(clientInfo.get("user")) == false) { + logger.warn("Token was originally created by [{}] but [{}] attempted to refresh it", clientInfo.get("user"), + clientAuth.getUser().principal()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else if (clientAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) { + logger.warn("[{}] created the refresh token while authenticated by [{}] but is now authenticated by [{}]", + clientInfo.get("user"), clientInfo.get("realm"), clientAuth.getAuthenticatedBy().getName()); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); } else { return Optional.empty(); } } + /** + * Checks if the retrieved refresh token is already refreshed taking into consideration that we allow refresh tokens + * to be refreshed multiple times for a very small time window in order to gracefully handle multiple concurrent requests + * from clients + */ + @SuppressWarnings("unchecked") + private Optional checkLenientlyIfTokenAlreadyRefreshed(Map source, + Authentication userAuth) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final Integer version = (Integer) userTokenSource.get("version"); + Version authVersion = Version.fromId(version); + final Boolean refreshed = (Boolean) refreshTokenSrc.get("refreshed"); + if (refreshed) { + if (authVersion.onOrAfter(Version.V_7_1_0)) { + final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); + final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); + final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); + if (supersededBy == null) { + return Optional.of(invalidGrantException("token document is missing superseded by value")); + } else if (refreshTime == null) { + return Optional.of(invalidGrantException("token document is missing refresh time value")); + } else if (clock.instant().isAfter(refreshTime.plus(30L, ChronoUnit.SECONDS))) { + return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past")); + } + } else { + return Optional.of(invalidGrantException("token has already been refreshed")); + } + } + return checkClient(refreshTokenSrc, userAuth); + } + + /** + * Checks if a refreshed token is eligible to be refreshed again. This is only allowed for versions after 7.1.0 and + * when the refresh_token contains the refresh_time and superseded_by fields and it has been refreshed in a specific + * time period of 60 seconds. The period is defined as 30 seconds before the token was refreshed until 30 seconds after. The + * time window needs to handle instants before the request time as we capture an instant early on in + * {@link TokenService#refreshToken(String, ActionListener)} and in the case of multiple concurrent requests, + * the {@code refreshRequested} when dealing with one of the subsequent requests might well be before the instant when + * the first of the requests refreshed the token. + * + * @param source The source of the token document that contains the originally refreshed token + * @param refreshRequested The instant when the this refresh request was acknowledged by the TokenService + */ + private boolean eligibleForMultiRefresh(Map source, Instant refreshRequested) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final Integer version = (Integer) userTokenSource.get("version"); + Version authVersion = Version.fromId(version); + final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time"); + final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli); + final String supersededBy = (String) refreshTokenSrc.get("superseded_by"); + return authVersion.onOrAfter(Version.V_7_1_0) + && supersededBy != null + && refreshTime != null + && refreshRequested.isBefore(refreshTime.plus(30L, ChronoUnit.SECONDS)) + && refreshRequested.isAfter(refreshTime.minus(30L, ChronoUnit.SECONDS)); + } + /** * Find stored refresh and access tokens that have not been invalidated or expired, and were issued against - * the specified realm. + * the specified realm. * * @param realmName The name of the realm for which to get the tokens - * @param listener The listener to notify upon completion - * @param filter an optional Predicate to test the source of the found documents against + * @param listener The listener to notify upon completion + * @param filter an optional Predicate to test the source of the found documents against */ public void findActiveTokensForRealm(String realmName, ActionListener>> listener, @Nullable Predicate> filter) { @@ -893,7 +1148,6 @@ public final class TokenService { */ public void findActiveTokensForUser(String username, ActionListener>> listener) { ensureEnabled(); - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); if (Strings.isNullOrEmpty(username)) { listener.onFailure(new IllegalArgumentException("username is required")); @@ -958,17 +1212,15 @@ public final class TokenService { } /** - * * Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token * * @param source The token document source as retrieved * @param filter an optional Predicate to test the source of the UserToken against * @return A {@link Tuple} of access-token and refresh-token-id or null if a Predicate is defined and the userToken source doesn't - * satisfy it + * satisfy it */ private Tuple parseTokensFromDocument(Map source, @Nullable Predicate> filter) throws IOException { - final String refreshToken = (String) ((Map) source.get("refresh_token")).get("token"); final Map userTokenSource = (Map) ((Map) source.get("access_token")).get("user_token"); @@ -986,7 +1238,7 @@ public final class TokenService { in.setVersion(authVersion); Authentication authentication = new Authentication(in); return new Tuple<>(new UserToken(id, Version.fromId(version), authentication, Instant.ofEpochMilli(expiration), metadata), - refreshToken); + refreshToken); } } @@ -1063,7 +1315,6 @@ public final class TokenService { } } - public TimeValue getExpirationDelay() { return expirationDelay; } @@ -1088,35 +1339,48 @@ public final class TokenService { private String getFromHeader(ThreadContext threadContext) { String header = threadContext.getHeader("Authorization"); if (Strings.hasText(header) && header.regionMatches(true, 0, "Bearer ", 0, "Bearer ".length()) - && header.length() > "Bearer ".length()) { + && header.length() > "Bearer ".length()) { return header.substring("Bearer ".length()); } return null; } /** - * Serializes a token to a String containing an encrypted representation of the token + * Serializes a token to a String containing the version of the node that created the token and + * either an encrypted representation of the token id for versions earlier to 7.0.0 or the token ie + * itself for versions after 7.0.0 */ - public String getUserTokenString(UserToken userToken) throws IOException, GeneralSecurityException { - // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly - try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); - OutputStream base64 = Base64.getEncoder().wrap(os); - StreamOutput out = new OutputStreamStreamOutput(base64)) { - out.setVersion(userToken.getVersion()); - KeyAndCache keyAndCache = keyCache.activeKeyCache; - Version.writeVersion(userToken.getVersion(), out); - out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); - final byte[] initializationVector = getNewInitializationVector(); - out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = - new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); - StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { - encryptedStreamOutput.setVersion(userToken.getVersion()); - encryptedStreamOutput.writeString(userToken.getId()); - encryptedStreamOutput.close(); + public String getAccessTokenAsString(UserToken userToken) throws IOException, GeneralSecurityException { + if (clusterService.state().nodes().getMinNodeVersion().onOrAfter(Version.V_7_1_0)) { + try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); + Version.writeVersion(userToken.getVersion(), out); + out.writeString(userToken.getId()); return new String(os.toByteArray(), StandardCharsets.UTF_8); } + } else { + // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly + try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); + KeyAndCache keyAndCache = keyCache.activeKeyCache; + Version.writeVersion(userToken.getVersion(), out); + out.writeByteArray(keyAndCache.getSalt().bytes); + out.writeByteArray(keyAndCache.getKeyHash().bytes); + final byte[] initializationVector = getNewInitializationVector(); + out.writeByteArray(initializationVector); + try (CipherOutputStream encryptedOutput = + new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); + StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(userToken.getVersion()); + encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.close(); + return new String(os.toByteArray(), StandardCharsets.UTF_8); + } + } } } @@ -1125,7 +1389,8 @@ public final class TokenService { SecretKeyFactory.getInstance(KDF_ALGORITHM); } - private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { + // Package private for testing + Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); BytesKey salt = keyAndCache.getSalt(); try { @@ -1147,7 +1412,8 @@ public final class TokenService { return cipher; } - private byte[] getNewInitializationVector() { + // Package private for testing + byte[] getNewInitializationVector() { final byte[] initializationVector = new byte[IV_BYTES]; secureRandom.nextBytes(initializationVector); return initializationVector; @@ -1158,7 +1424,7 @@ public final class TokenService { * This method is computationally expensive. */ static SecretKey computeSecretKey(char[] rawPassword, byte[] salt) - throws NoSuchAlgorithmException, InvalidKeySpecException { + throws NoSuchAlgorithmException, InvalidKeySpecException { SecretKeyFactory secretKeyFactory = SecretKeyFactory.getInstance(KDF_ALGORITHM); PBEKeySpec keySpec = new PBEKeySpec(rawPassword, salt, ITERATIONS, 128); SecretKey tmp = secretKeyFactory.generateSecret(keySpec); @@ -1172,7 +1438,7 @@ public final class TokenService { */ private static ElasticsearchSecurityException expiredTokenException() { ElasticsearchSecurityException e = - new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); + new ElasticsearchSecurityException("token expired", RestStatus.UNAUTHORIZED); e.addHeader("WWW-Authenticate", EXPIRED_TOKEN_WWW_AUTH_VALUE); return e; } @@ -1269,8 +1535,8 @@ public final class TokenService { listener.onResponse(computedKey); } catch (ExecutionException e) { if (e.getCause() != null && - (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException - || e.getCause() instanceof IllegalArgumentException)) { + (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException + || e.getCause() instanceof IllegalArgumentException)) { // this could happen if another realm supports the Bearer token so we should // see if another realm can use this token! logger.debug("unable to decode bearer token", e); @@ -1305,7 +1571,7 @@ public final class TokenService { continue; // collision -- generate a new key } return newTokenMetaData(keyCache.currentTokenKeyHash, Iterables.concat(keyCache.cache.values(), - Collections.singletonList(keyAndCache))); + Collections.singletonList(keyAndCache))); } } return newTokenMetaData(keyCache.currentTokenKeyHash, keyCache.cache.values()); @@ -1335,10 +1601,10 @@ public final class TokenService { KeyAndCache currentKey = keyCache.get(keyCache.currentTokenKeyHash); ArrayList entries = new ArrayList<>(keyCache.cache.values()); Collections.sort(entries, - (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); + (left, right) -> Long.compare(right.keyAndTimestamp.getTimestamp(), left.keyAndTimestamp.getTimestamp())); for (KeyAndCache value : entries) { if (map.size() < numKeysToKeep || value.keyAndTimestamp.getTimestamp() >= currentKey - .keyAndTimestamp.getTimestamp()) { + .keyAndTimestamp.getTimestamp()) { logger.debug("keeping key {} ", value.getKeyHash()); map.put(value.getKeyHash(), value); } else { @@ -1417,16 +1683,16 @@ public final class TokenService { logger.info("rotate keys on master"); TokenMetaData tokenMetaData = generateSpareKey(); clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction( - ActionListener.wrap((res) -> { - if (res.isAcknowledged()) { - TokenMetaData metaData = rotateToSpareKey(); - clusterService.submitStateUpdateTask("publish next key to prepare key rotation", - new TokenMetadataPublishAction(listener, metaData)); - } else { - listener.onFailure(new IllegalStateException("not acked")); - } - }, listener::onFailure), tokenMetaData)); + new TokenMetadataPublishAction( + ActionListener.wrap((res) -> { + if (res.isAcknowledged()) { + TokenMetaData metaData = rotateToSpareKey(); + clusterService.submitStateUpdateTask("publish next key to prepare key rotation", + new TokenMetadataPublishAction(listener, metaData)); + } else { + listener.onFailure(new IllegalStateException("not acked")); + } + }, listener::onFailure), tokenMetaData)); } private final class TokenMetadataPublishAction extends AckedClusterStateUpdateTask { @@ -1528,12 +1794,19 @@ public final class TokenService { } /** - * For testing + * Package private for testing */ void clearActiveKeyCache() { this.keyCache.activeKeyCache.keyCache.invalidateAll(); } + /** + * Package private for testing + */ + KeyAndCache getActiveKeyCache() { + return this.keyCache.activeKeyCache; + } + static final class KeyAndCache implements Closeable { private final KeyAndTimestamp keyAndTimestamp; private final Cache keyCache; @@ -1543,9 +1816,9 @@ public final class TokenService { private KeyAndCache(KeyAndTimestamp keyAndTimestamp, BytesKey salt) { this.keyAndTimestamp = keyAndTimestamp; keyCache = CacheBuilder.builder() - .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) - .setMaximumWeight(500L) - .build(); + .setExpireAfterAccess(TimeValue.timeValueMinutes(60L)) + .setMaximumWeight(500L) + .build(); try { SecretKey secretKey = computeSecretKey(keyAndTimestamp.getKey().getChars(), salt.bytes); keyCache.put(salt, secretKey); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 085df140f3e..795cc9fb225 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -242,7 +242,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true); final UserToken userToken = future.actionGet().v1(); mockGetTokenFromId(userToken, false, client); - final String tokenString = tokenService.getUserTokenString(userToken); + final String tokenString = tokenService.getAccessTokenAsString(userToken); final SamlLogoutRequest request = new SamlLogoutRequest(); request.setToken(tokenString); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index 5eee33711e2..cda0586886c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -1109,7 +1109,7 @@ public class AuthenticationServiceTests extends ESTestCase { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); } - String token = tokenService.getUserTokenString(tokenFuture.get().v1()); + String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); mockGetTokenFromId(tokenFuture.get().v1(), false, client); when(securityIndex.isAvailable()).thenReturn(true); @@ -1192,7 +1192,7 @@ public class AuthenticationServiceTests extends ESTestCase { Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); } - String token = tokenService.getUserTokenString(tokenFuture.get().v1()); + String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1()); mockGetTokenFromId(tokenFuture.get().v1(), true, client); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[1]).run(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index 61ea4ef9672..7499d8be7d1 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -5,11 +5,14 @@ */ package org.elasticsearch.xpack.security.authc; +import org.apache.directory.api.util.Strings; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.WriteRequest; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.common.settings.SecureString; @@ -23,6 +26,7 @@ import org.elasticsearch.test.SecuritySettingsSource; import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; @@ -38,7 +42,13 @@ import org.junit.Before; import java.time.Instant; import java.time.temporal.ChronoUnit; +import java.util.ArrayList; import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -330,7 +340,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); } - public void testRefreshingMultipleTimes() { + public void testRefreshingMultipleTimesFails() throws Exception { Client client = client().filterWithHeader(Collections.singletonMap("Authorization", UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); @@ -343,12 +353,101 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { assertNotNull(createTokenResponse.getRefreshToken()); CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); assertNotNull(refreshResponse); + // We now have two documents, the original(now refreshed) token doc and the new one with the new access doc + AtomicReference docId = new AtomicReference<>(); + assertBusy(() -> { + SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setSource(SearchSourceBuilder.searchSource() + .query(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("doc_type", "token")) + .must(QueryBuilders.termQuery("refresh_token.refreshed", "true")))) + .setSize(1) + .setTerminateAfter(1) + .get(); + assertThat(searchResponse.getHits().getTotalHits().value, equalTo(1L)); + docId.set(searchResponse.getHits().getAt(0).getId()); + }); + // hack doc to modify the refresh time to 50 seconds ago so that we don't hit the lenient refresh case + Instant refreshed = Instant.now(); + Instant aWhileAgo = refreshed.minus(50L, ChronoUnit.SECONDS); + assertTrue(Instant.now().isAfter(aWhileAgo)); + UpdateResponse updateResponse = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get()) + .setDoc("refresh_token", Collections.singletonMap("refresh_time", aWhileAgo.toEpochMilli())) + .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) + .setFetchSource("refresh_token", Strings.EMPTY_STRING) + .get(); + assertNotNull(updateResponse); + Map refreshTokenMap = (Map) updateResponse.getGetResult().sourceAsMap().get("refresh_token"); + assertTrue( + Instant.ofEpochMilli((long) refreshTokenMap.get("refresh_time")).isBefore(Instant.now().minus(30L, ChronoUnit.SECONDS))); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); assertEquals("invalid_grant", e.getMessage()); assertEquals(RestStatus.BAD_REQUEST, e.status()); - assertEquals("token has already been refreshed", e.getHeader("error_description").get(0)); + assertEquals("token has already been refreshed more than 30 seconds in the past", e.getHeader("error_description").get(0)); + } + + public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + Set refreshTokens = new HashSet<>(); + Set accessTokens = new HashSet<>(); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + final int numberOfProcessors = Runtime.getRuntime().availableProcessors(); + final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3); + List threads = new ArrayList<>(numberOfThreads); + final CountDownLatch readyLatch = new CountDownLatch(numberOfThreads + 1); + final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads); + AtomicBoolean failed = new AtomicBoolean(); + for (int i = 0; i < numberOfThreads; i++) { + threads.add(new Thread(() -> { + // Each thread gets its own client so that more than one nodes will be hit + Client threadClient = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient threadSecurityClient = new SecurityClient(threadClient); + CreateTokenRequest refreshRequest = + threadSecurityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).request(); + readyLatch.countDown(); + try { + readyLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + completedLatch.countDown(); + return; + } + threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> { + accessTokens.add(result.getTokenString()); + refreshTokens.add(result.getRefreshToken()); + logger.info("received access token [{}] and refresh token [{}]", result.getTokenString(), result.getRefreshToken()); + completedLatch.countDown(); + }, e -> { + failed.set(true); + completedLatch.countDown(); + logger.error("caught exception", e); + })); + })); + } + for (Thread thread : threads) { + thread.start(); + } + readyLatch.countDown(); + readyLatch.await(); + for (Thread thread : threads) { + thread.join(); + } + completedLatch.await(); + assertThat(failed.get(), equalTo(false)); + assertThat(accessTokens.size(), equalTo(1)); + assertThat(refreshTokens.size(), equalTo(1)); } public void testRefreshAsDifferentUser() { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 8caf82e8648..7efb4b51632 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.xpack.security.authc; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.get.GetAction; @@ -23,6 +24,8 @@ import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Tuple; +import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; @@ -51,7 +54,11 @@ import org.junit.AfterClass; import org.junit.Before; import org.junit.BeforeClass; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.security.GeneralSecurityException; import java.time.Clock; import java.time.Instant; import java.time.temporal.ChronoUnit; @@ -61,6 +68,7 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Consumer; +import javax.crypto.CipherOutputStream; import javax.crypto.SecretKey; import static java.time.Clock.systemUTC; @@ -151,7 +159,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -198,7 +206,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -219,10 +227,10 @@ public class TokenServiceTests extends ESTestCase { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); + assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { @@ -247,7 +255,7 @@ public class TokenServiceTests extends ESTestCase { rotateKeys(tokenService); } TokenService otherTokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, - clusterService); + clusterService); otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); PlainActionFuture> tokenFuture = new PlainActionFuture<>(); @@ -258,7 +266,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); otherTokenService.getAndValidateToken(requestContext, future); @@ -289,7 +297,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -316,7 +324,7 @@ public class TokenServiceTests extends ESTestCase { tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); - assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); + assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token)); metaData = tokenService.pruneKeys(1); tokenService.refreshMetaData(metaData); @@ -329,7 +337,7 @@ public class TokenServiceTests extends ESTestCase { } requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken)); mockGetTokenFromId(newToken, false); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -351,7 +359,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -362,8 +370,8 @@ public class TokenServiceTests extends ESTestCase { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // verify a second separate token service with its own passphrase cannot verify - TokenService anotherService = new TokenService(Settings.EMPTY, systemUTC(), client, securityIndex, - clusterService); + TokenService anotherService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex, + clusterService); PlainActionFuture future = new PlainActionFuture<>(); anotherService.getAndValidateToken(requestContext, future); assertNull(future.get()); @@ -377,10 +385,10 @@ public class TokenServiceTests extends ESTestCase { PlainActionFuture> tokenFuture = new PlainActionFuture<>(); tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true); UserToken token = tokenFuture.get().v1(); - assertThat(tokenService.getUserTokenString(token), notNullValue()); + assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); tokenService.clearActiveKeyCache(); - assertThat(tokenService.getUserTokenString(token), notNullValue()); + assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue()); } public void testInvalidatedToken() throws Exception { @@ -395,7 +403,7 @@ public class TokenServiceTests extends ESTestCase { mockGetTokenFromId(token, true); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { PlainActionFuture future = new PlainActionFuture<>(); @@ -449,7 +457,7 @@ public class TokenServiceTests extends ESTestCase { authentication = token.getAuthentication(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { // the clock is still frozen, so the cookie should be valid @@ -559,7 +567,7 @@ public class TokenServiceTests extends ESTestCase { //mockGetTokenFromId(token, false); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); - requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); + requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token)); doAnswer(invocationOnMock -> { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -598,7 +606,7 @@ public class TokenServiceTests extends ESTestCase { Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); mockGetTokenFromId(expired, false); - String userTokenString = tokenService.getUserTokenString(expired); + String userTokenString = tokenService.getAccessTokenAsString(expired); PlainActionFuture>> authFuture = new PlainActionFuture<>(); tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); Authentication retrievedAuth = authFuture.actionGet().v1(); @@ -639,4 +647,28 @@ public class TokenServiceTests extends ESTestCase { assertEquals(expected.getMetadata(), result.getMetadata()); assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); } + + protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException, + GeneralSecurityException { + try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES); + OutputStream base64 = Base64.getEncoder().wrap(os); + StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(Version.V_7_0_0); + TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache(); + Version.writeVersion(Version.V_7_0_0, out); + out.writeByteArray(keyAndCache.getSalt().bytes); + out.writeByteArray(keyAndCache.getKeyHash().bytes); + final byte[] initializationVector = tokenService.getNewInitializationVector(); + out.writeByteArray(initializationVector); + try (CipherOutputStream encryptedOutput = + new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0)); + StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(Version.V_7_0_0); + encryptedStreamOutput.writeString(userToken.getId()); + encryptedStreamOutput.close(); + return new String(os.toByteArray(), StandardCharsets.UTF_8); + } + } + } + } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java index f180e356b76..55ae297ae4e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java @@ -25,8 +25,7 @@ public class TokensInvalidationResultTests extends ESTestCase { TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Arrays.asList("token3", "token4"), Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), - new ElasticsearchException("boo", new IllegalStateException("far"))), - randomIntBetween(0, 5)); + new ElasticsearchException("boo", new IllegalStateException("far")))); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -56,9 +55,8 @@ public class TokensInvalidationResultTests extends ESTestCase { } public void testToXcontentWithNoErrors() throws Exception{ - TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), - Collections.emptyList(), - Collections.emptyList(), randomIntBetween(0, 5)); + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Collections.emptyList(), + Collections.emptyList()); try (XContentBuilder builder = JsonXContent.contentBuilder()) { result.toXContent(builder, ToXContent.EMPTY_PARAMS); assertThat(Strings.toString(builder),