Support concurrent refresh of refresh tokens (#39647)

This is a backport of #39631

Co-authored-by: Jay Modi jaymode@users.noreply.github.com

This change adds support for the concurrent refresh of access
tokens as described in #36872
In short it allows subsequent client requests to refresh the same token that
come within a predefined window of 60 seconds to be handled as duplicates
of the original one and thus receive the same response with the same newly
issued access token and refresh token.
In order to support that, two new fields are added in the token document. One
contains the instant (in epoqueMillis) when a given refresh token is refreshed
and one that contains a pointer to the token document that stores the new
refresh token and access token that was created by the original refresh.
A side effect of this change, that was however also a intended enhancement
for the token service, is that we needed to stop encrypting the string
representation of the UserToken while serializing. ( It was necessary as we
correctly used a new IV for every time we encrypted a token in serialization, so
subsequent serializations of the same exact UserToken would produce
different access token strings)

This change also handles the serialization/deserialization BWC logic:

    In mixed clusters we keep creating tokens in the old format and
    consume only old format tokens
    In upgraded clusters, we start creating tokens in the new format but
    still remain able to consume old format tokens (that could have been
    created during the rolling upgrade and are still valid)
    When reading/writing TokensInvalidationResult objects, we take into
    consideration that pre 7.1.0 these contained an integer field that carried
    the attempt count

Resolves #36872
This commit is contained in:
Ioannis Kakavas 2019-03-05 14:55:59 +02:00 committed by GitHub
parent e8d9744340
commit 7ed9d52824
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 710 additions and 305 deletions

View File

@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.security.authc.support; package org.elasticsearch.xpack.core.security.authc.support;
import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.StreamOutput;
@ -32,10 +33,9 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
private final List<String> invalidatedTokens; private final List<String> invalidatedTokens;
private final List<String> previouslyInvalidatedTokens; private final List<String> previouslyInvalidatedTokens;
private final List<ElasticsearchException> errors; private final List<ElasticsearchException> errors;
private final int attemptCount;
public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens, public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens,
@Nullable List<ElasticsearchException> errors, int attemptCount) { @Nullable List<ElasticsearchException> errors) {
Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided");
this.invalidatedTokens = invalidatedTokens; this.invalidatedTokens = invalidatedTokens;
Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided");
@ -45,18 +45,19 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
} else { } else {
this.errors = Collections.emptyList(); this.errors = Collections.emptyList();
} }
this.attemptCount = attemptCount;
} }
public TokensInvalidationResult(StreamInput in) throws IOException { public TokensInvalidationResult(StreamInput in) throws IOException {
this.invalidatedTokens = in.readStringList(); this.invalidatedTokens = in.readStringList();
this.previouslyInvalidatedTokens = in.readStringList(); this.previouslyInvalidatedTokens = in.readStringList();
this.errors = in.readList(StreamInput::readException); this.errors = in.readList(StreamInput::readException);
this.attemptCount = in.readVInt(); if (in.getVersion().before(Version.V_7_1_0)) {
in.readVInt();
}
} }
public static TokensInvalidationResult emptyResult() { public static TokensInvalidationResult emptyResult() {
return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList());
} }
@ -72,10 +73,6 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
return errors; return errors;
} }
public int getAttemptCount() {
return attemptCount;
}
@Override @Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject() builder.startObject()
@ -100,6 +97,8 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
out.writeStringCollection(invalidatedTokens); out.writeStringCollection(invalidatedTokens);
out.writeStringCollection(previouslyInvalidatedTokens); out.writeStringCollection(previouslyInvalidatedTokens);
out.writeCollection(errors, StreamOutput::writeException); out.writeCollection(errors, StreamOutput::writeException);
out.writeVInt(attemptCount); if (out.getVersion().before(Version.V_7_1_0)) {
out.writeVInt(5);
}
} }
} }

View File

@ -199,6 +199,13 @@
"refreshed" : { "refreshed" : {
"type" : "boolean" "type" : "boolean"
}, },
"refresh_time": {
"type": "date",
"format": "epoch_millis"
},
"superseded_by": {
"type": "keyword"
},
"invalidated" : { "invalidated" : {
"type" : "boolean" "type" : "boolean"
}, },

View File

@ -29,8 +29,7 @@ public class InvalidateTokenResponseTests extends ESTestCase {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))));
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result); InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) { try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output); response.writeTo(output);
@ -47,8 +46,7 @@ public class InvalidateTokenResponseTests extends ESTestCase {
} }
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)), Arrays.asList(generateRandomStringArray(20, 15, false)), Collections.emptyList());
Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result); response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) { try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output); response.writeTo(output);
@ -68,8 +66,7 @@ public class InvalidateTokenResponseTests extends ESTestCase {
List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens,
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))));
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result); InvalidateTokenResponse response = new InvalidateTokenResponse(result);
XContentBuilder builder = XContentFactory.jsonBuilder(); XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS); response.toXContent(builder, ToXContent.EMPTY_PARAMS);

View File

@ -63,7 +63,7 @@ public final class TransportSamlAuthenticateAction extends HandledTransportActio
final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA); final Map<String, Object> tokenMeta = (Map<String, Object>) result.getMetadata().get(SamlRealm.CONTEXT_TOKEN_DATA);
tokenService.createUserToken(authentication, originatingAuthentication, tokenService.createUserToken(authentication, originatingAuthentication,
ActionListener.wrap(tuple -> { ActionListener.wrap(tuple -> {
final String tokenString = tokenService.getUserTokenString(tuple.v1()); final String tokenString = tokenService.getAccessTokenAsString(tuple.v1());
final TimeValue expiresIn = tokenService.getExpirationDelay(); final TimeValue expiresIn = tokenService.getExpirationDelay();
listener.onResponse( listener.onResponse(
new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn)); new SamlAuthenticateResponse(authentication.getUser().principal(), tokenString, tuple.v2(), expiresIn));

View File

@ -89,7 +89,7 @@ public final class TransportCreateTokenAction extends HandledTransportAction<Cre
boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) { boolean includeRefreshToken, ActionListener<CreateTokenResponse> listener) {
try { try {
tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> { tokenService.createUserToken(authentication, originatingAuth, ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getUserTokenString(tuple.v1()); final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope()); final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response = final CreateTokenResponse response =

View File

@ -31,7 +31,7 @@ public class TransportRefreshTokenAction extends HandledTransportAction<CreateTo
@Override @Override
protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) { protected void doExecute(Task task, CreateTokenRequest request, ActionListener<CreateTokenResponse> listener) {
tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> {
final String tokenStr = tokenService.getUserTokenString(tuple.v1()); final String tokenStr = tokenService.getAccessTokenAsString(tuple.v1());
final String scope = getResponseScopeValue(request.getScope()); final String scope = getResponseScopeValue(request.getScope());
final CreateTokenResponse response = final CreateTokenResponse response =

View File

@ -19,6 +19,7 @@ import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest.OpType; import org.elasticsearch.action.DocWriteRequest.OpType;
import org.elasticsearch.action.DocWriteResponse; import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.bulk.BackoffPolicy;
import org.elasticsearch.action.bulk.BulkItemResponse; import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequestBuilder; import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.bulk.BulkResponse;
@ -28,6 +29,7 @@ import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.TransportActions;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.support.master.AcknowledgedRequest;
import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequest;
@ -67,6 +69,7 @@ import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.seqno.SequenceNumbers;
import org.elasticsearch.rest.RestStatus; import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.search.SearchHit; import org.elasticsearch.search.SearchHit;
import org.elasticsearch.xpack.core.XPackField; import org.elasticsearch.xpack.core.XPackField;
@ -113,12 +116,12 @@ import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Predicate; import java.util.function.Predicate;
@ -129,6 +132,7 @@ import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK
import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING; import static org.elasticsearch.search.SearchService.DEFAULT_KEEPALIVE_SETTING;
import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
import static org.elasticsearch.threadpool.ThreadPool.Names.GENERIC;
/** /**
* Service responsible for the creation, validation, and other management of {@link UserToken} * Service responsible for the creation, validation, and other management of {@link UserToken}
@ -155,6 +159,7 @@ public final class TokenService {
private static final String MALFORMED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY + private static final String MALFORMED_TOKEN_WWW_AUTH_VALUE = "Bearer realm=\"" + XPackField.SECURITY +
"\", error=\"invalid_token\", error_description=\"The access token is malformed\""; "\", error=\"invalid_token\", error_description=\"The access token is malformed\"";
private static final String TYPE = "doc"; private static final String TYPE = "doc";
private static final BackoffPolicy DEFAULT_BACKOFF = BackoffPolicy.exponentialBackoff();
public static final String THREAD_POOL_NAME = XPackField.SECURITY + "-token-key"; public static final String THREAD_POOL_NAME = XPackField.SECURITY + "-token-key";
public static final Setting<TimeValue> TOKEN_EXPIRATION = Setting.timeSetting("xpack.security.authc.token.timeout", public static final Setting<TimeValue> TOKEN_EXPIRATION = Setting.timeSetting("xpack.security.authc.token.timeout",
@ -167,8 +172,7 @@ public final class TokenService {
private static final String TOKEN_DOC_TYPE = "token"; private static final String TOKEN_DOC_TYPE = "token";
private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_"; private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_";
static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1;
private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue();
private static final int MAX_RETRY_ATTEMPTS = 5;
private static final Logger logger = LogManager.getLogger(TokenService.class); private static final Logger logger = LogManager.getLogger(TokenService.class);
private final SecureRandom secureRandom = new SecureRandom(); private final SecureRandom secureRandom = new SecureRandom();
@ -221,12 +225,22 @@ public final class TokenService {
} }
/** /**
* Create a token based on the provided authentication and metadata. * Creates a token based on the provided authentication and metadata with an auto-generated token id.
* The created token will be stored in the security index. * The created token will be stored in the security index.
*/ */
public void createUserToken(Authentication authentication, Authentication originatingClientAuth, public void createUserToken(Authentication authentication, Authentication originatingClientAuth,
ActionListener<Tuple<UserToken, String>> listener, Map<String, Object> metadata, ActionListener<Tuple<UserToken, String>> listener, Map<String, Object> metadata,
boolean includeRefreshToken) throws IOException { boolean includeRefreshToken) throws IOException {
createUserToken(UUIDs.randomBase64UUID(), authentication, originatingClientAuth, listener, metadata, includeRefreshToken);
}
/**
* Create a token based on the provided authentication and metadata with the given token id.
* The created token will be stored in the security index.
*/
private void createUserToken(String userTokenId, Authentication authentication, Authentication originatingClientAuth,
ActionListener<Tuple<UserToken, String>> listener, Map<String, Object> metadata,
boolean includeRefreshToken) throws IOException {
ensureEnabled(); ensureEnabled();
if (authentication == null) { if (authentication == null) {
listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided")));
@ -239,7 +253,7 @@ public final class TokenService {
final Version version = clusterService.state().nodes().getMinNodeVersion(); final Version version = clusterService.state().nodes().getMinNodeVersion();
final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), final Authentication tokenAuth = new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(),
authentication.getLookedUpBy(), version, AuthenticationType.TOKEN, authentication.getMetadata()); authentication.getLookedUpBy(), version, AuthenticationType.TOKEN, authentication.getMetadata());
final UserToken userToken = new UserToken(version, tokenAuth, expiration, metadata); final UserToken userToken = new UserToken(userTokenId, version, tokenAuth, expiration, metadata);
final String refreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null; final String refreshToken = includeRefreshToken ? UUIDs.randomBase64UUID() : null;
try (XContentBuilder builder = XContentFactory.jsonBuilder()) { try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
@ -280,9 +294,33 @@ public final class TokenService {
} }
} }
/**
* Reconstructs the {@link UserToken} from the existing {@code userTokenSource} and call the listener with the {@link UserToken} and the
* refresh token string
*/
private void reIssueTokens(Map<String, Object> userTokenSource,
String refreshToken, ActionListener<Tuple<UserToken, String>> listener) {
final String authString = (String) userTokenSource.get("authentication");
final Integer version = (Integer) userTokenSource.get("version");
final Map<String, Object> metadata = (Map<String, Object>) userTokenSource.get("metadata");
final String id = (String) userTokenSource.get("id");
final Long expiration = (Long) userTokenSource.get("expiration_time");
Version authVersion = Version.fromId(version);
try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) {
in.setVersion(authVersion);
Authentication authentication = new Authentication(in);
UserToken userToken = new UserToken(id, authVersion, authentication, Instant.ofEpochMilli(expiration), metadata);
listener.onResponse(new Tuple<>(userToken, refreshToken));
} catch (IOException e) {
logger.error("Unable to decode existing user token", e);
listener.onFailure(invalidGrantException("could not refresh the requested token"));
}
}
/** /**
* Looks in the context to see if the request provided a header with a user token and if so the * Looks in the context to see if the request provided a header with a user token and if so the
* token is validated, which includes authenticated decryption and verification that the token * token is validated, which might include authenticated decryption and verification that the token
* has not been revoked or is expired. * has not been revoked or is expired.
*/ */
void getAndValidateToken(ThreadContext ctx, ActionListener<UserToken> listener) { void getAndValidateToken(ThreadContext ctx, ActionListener<UserToken> listener) {
@ -329,42 +367,19 @@ public final class TokenService {
} }
/** /**
* Asynchronously decodes the string representation of a {@link UserToken}. The process for * Gets the UserToken with given id by fetching the the corresponding token document
* this is asynchronous as we may need to compute a key, which can be computationally expensive
* so this should not block the current thread, which is typically a network thread. A second
* reason for being asynchronous is that we can restrain the amount of resources consumed by
* the key computation to a single thread.
*/ */
void decodeToken(String token, ActionListener<UserToken> listener) throws IOException { void getUserTokenFromId(String userTokenId, ActionListener<UserToken> listener) {
// We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key!
byte[] bytes = token.getBytes(StandardCharsets.UTF_8);
StreamInput in = new InputStreamStreamInput(Base64.getDecoder().wrap(new ByteArrayInputStream(bytes)), bytes.length);
if (in.available() < MINIMUM_BASE64_BYTES) {
logger.debug("invalid token");
listener.onResponse(null);
} else {
// the token exists and the value is at least as long as we'd expect
final Version version = Version.readVersion(in);
in.setVersion(version);
final BytesKey decodedSalt = new BytesKey(in.readByteArray());
final BytesKey passphraseHash = new BytesKey(in.readByteArray());
KeyAndCache keyAndCache = keyCache.get(passphraseHash);
if (keyAndCache != null) {
getKeyAsync(decodedSalt, keyAndCache, ActionListener.wrap(decodeKey -> {
try {
final byte[] iv = in.readByteArray();
final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt);
decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> {
if (securityIndex.isAvailable() == false) { if (securityIndex.isAvailable() == false) {
logger.warn("failed to get token [{}] since index is not available", tokenId); logger.warn("failed to get token [{}] since index is not available", userTokenId);
listener.onResponse(null); listener.onResponse(null);
} else { } else {
securityIndex.checkIndexVersionThenExecute( securityIndex.checkIndexVersionThenExecute(
ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)), ex -> listener.onFailure(traceLog("prepare security index", userTokenId, ex)),
() -> { () -> {
final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE,
getTokenDocumentId(tokenId)).request(); getTokenDocumentId(userTokenId)).request();
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex)); Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("decode token", userTokenId, ex));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest,
ActionListener.<GetResponse>wrap(response -> { ActionListener.<GetResponse>wrap(response -> {
if (response.isExists()) { if (response.isExists()) {
@ -389,16 +404,51 @@ public final class TokenService {
// if the index or the shard is not there / available we assume that // if the index or the shard is not there / available we assume that
// the token is not valid // the token is not valid
if (isShardNotAvailableException(e)) { if (isShardNotAvailableException(e)) {
logger.warn("failed to get token [{}] since index is not available", tokenId); logger.warn("failed to get token [{}] since index is not available", userTokenId);
listener.onResponse(null); listener.onResponse(null);
} else { } else {
logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e); logger.error(new ParameterizedMessage("failed to get token [{}]", userTokenId), e);
listener.onFailure(e); listener.onFailure(e);
} }
}), client::get); }), client::get);
}); });
} }
}, listener::onFailure)); }
/*
* If needed, for tokens that were created in a pre 7.1.0 cluster, it asynchronously decodes the token to get the token document Id.
* The process for this is asynchronous as we may need to compute a key, which can be computationally expensive
* so this should not block the current thread, which is typically a network thread. A second reason for being asynchronous is that
* we can restrain the amount of resources consumed by the key computation to a single thread.
* For tokens created in an after 7.1.0 cluster, the token is just the token document Id so this is used directly without decryption
*/
void decodeToken(String token, ActionListener<UserToken> listener) throws IOException {
// We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key!
byte[] bytes = token.getBytes(StandardCharsets.UTF_8);
StreamInput in = new InputStreamStreamInput(Base64.getDecoder().wrap(new ByteArrayInputStream(bytes)), bytes.length);
final Version version = Version.readVersion(in);
if (version.onOrAfter(Version.V_7_1_0)) {
// The token was created in a > 7.1.0 cluster so it contains the tokenId as a String
String usedTokenId = in.readString();
getUserTokenFromId(usedTokenId, listener);
} else {
// The token was created in a < 7.1.0 cluster so we need to decrypt it to get the tokenId
in.setVersion(version);
if (in.available() < MINIMUM_BASE64_BYTES) {
logger.debug("invalid token, smaller than [{}] bytes", MINIMUM_BASE64_BYTES);
listener.onResponse(null);
return;
}
final BytesKey decodedSalt = new BytesKey(in.readByteArray());
final BytesKey passphraseHash = new BytesKey(in.readByteArray());
KeyAndCache keyAndCache = keyCache.get(passphraseHash);
if (keyAndCache != null) {
getKeyAsync(decodedSalt, keyAndCache, ActionListener.wrap(decodeKey -> {
try {
final byte[] iv = in.readByteArray();
final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt);
decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> getUserTokenFromId(tokenId, listener),
listener::onFailure));
} catch (GeneralSecurityException e) { } catch (GeneralSecurityException e) {
// could happen with a token that is not ours // could happen with a token that is not ours
logger.warn("invalid token", e); logger.warn("invalid token", e);
@ -442,8 +492,8 @@ public final class TokenService {
/** /**
* This method performs the steps necessary to invalidate a token so that it may no longer be * This method performs the steps necessary to invalidate a token so that it may no longer be
* used. The process of invalidation involves performing an update to * used. The process of invalidation involves performing an update to the token document and setting
* the token document and setting the <code>invalidated</code> field to <code>true</code> * the <code>invalidated</code> field to <code>true</code>
*/ */
public void invalidateAccessToken(String tokenString, ActionListener<TokensInvalidationResult> listener) { public void invalidateAccessToken(String tokenString, ActionListener<TokensInvalidationResult> listener) {
ensureEnabled(); ensureEnabled();
@ -452,12 +502,13 @@ public final class TokenService {
listener.onFailure(new IllegalArgumentException("token must be provided")); listener.onFailure(new IllegalArgumentException("token must be provided"));
} else { } else {
maybeStartTokenRemover(); maybeStartTokenRemover();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
try { try {
decodeToken(tokenString, ActionListener.wrap(userToken -> { decodeToken(tokenString, ActionListener.wrap(userToken -> {
if (userToken == null) { if (userToken == null) {
listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException())); listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException()));
} else { } else {
indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff,
"access_token", null); "access_token", null);
} }
}, listener::onFailure)); }, listener::onFailure));
@ -480,12 +531,14 @@ public final class TokenService {
listener.onFailure(new IllegalArgumentException("token must be provided")); listener.onFailure(new IllegalArgumentException("token must be provided"));
} else { } else {
maybeStartTokenRemover(); maybeStartTokenRemover();
indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null); final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
indexInvalidation(Collections.singleton(userToken.getId()), listener, backoff, "access_token", null);
} }
} }
/** /**
* This method performs the steps necessary to invalidate a refresh token so that it may no longer be used. * This method onvalidates a refresh token so that it may no longer be used. Iinvalidation involves performing an update to the token
* document and setting the <code>refresh_token.invalidated</code> field to <code>true</code>
* *
* @param refreshToken The string representation of the refresh token * @param refreshToken The string representation of the refresh token
* @param listener the listener to notify upon completion * @param listener the listener to notify upon completion
@ -497,16 +550,17 @@ public final class TokenService {
listener.onFailure(new IllegalArgumentException("refresh token must be provided")); listener.onFailure(new IllegalArgumentException("refresh token must be provided"));
} else { } else {
maybeStartTokenRemover(); maybeStartTokenRemover();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
findTokenFromRefreshToken(refreshToken, findTokenFromRefreshToken(refreshToken,
ActionListener.wrap(tuple -> { ActionListener.wrap(searchResponse -> {
final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId()); final String docId = getTokenIdFromDocumentId(searchResponse.getHits().getAt(0).getId());
indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null); indexInvalidation(Collections.singletonList(docId), listener, backoff, "refresh_token", null);
}, listener::onFailure), new AtomicInteger(0)); }, listener::onFailure), backoff);
} }
} }
/** /**
* Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given * Invalidates all access tokens and all refresh tokens of a given {@code realmName} and/or of a given
* {@code username} so that they may no longer be used * {@code username} so that they may no longer be used
* *
* @param realmName the realm of which the tokens should be invalidated * @param realmName the realm of which the tokens should be invalidated
@ -557,32 +611,30 @@ public final class TokenService {
maybeStartTokenRemover(); maybeStartTokenRemover();
// Invalidate the refresh tokens first so that they cannot be used to get new // Invalidate the refresh tokens first so that they cannot be used to get new
// access tokens while we invalidate the access tokens we currently know about // access tokens while we invalidate the access tokens we currently know about
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
indexInvalidation(accessTokenIds, ActionListener.wrap(result -> indexInvalidation(accessTokenIds, ActionListener.wrap(result ->
indexInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()), indexInvalidation(accessTokenIds, listener, backoff, "access_token", result),
"access_token", result), listener::onFailure), backoff, "refresh_token", null);
listener::onFailure), new AtomicInteger(0), "refresh_token", null);
} }
/** /**
* Performs the actual invalidation of a collection of tokens * Performs the actual invalidation of a collection of tokens. In case of recoverable errors ( see
* {@link TransportActions#isShardNotAvailableException} ) the UpdateRequests to mark the tokens as invalidated are retried using
* an exponential backoff policy.
* *
* @param tokenIds the tokens to invalidate * @param tokenIds the tokens to invalidate
* @param listener the listener to notify upon completion * @param listener the listener to notify upon completion
* @param attemptCount the number of attempts to invalidate that have already been tried * @param backoff the amount of time to delay between attempts
* @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on * @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on
* what type of tokens should be invalidated * what type of tokens should be invalidated
* @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating
* tokens up to the point of the retry. This result is added to the result of the current attempt * tokens up to the point of the retry. This result is added to the result of the current attempt
*/ */
private void indexInvalidation(Collection<String> tokenIds, ActionListener<TokensInvalidationResult> listener, private void indexInvalidation(Collection<String> tokenIds, ActionListener<TokensInvalidationResult> listener,
AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { Iterator<TimeValue> backoff, String srcPrefix, @Nullable TokensInvalidationResult previousResult) {
if (tokenIds.isEmpty()) { if (tokenIds.isEmpty()) {
logger.warn("No [{}] tokens provided for invalidation", srcPrefix); logger.warn("No [{}] tokens provided for invalidation", srcPrefix);
listener.onFailure(invalidGrantException("No tokens provided for invalidation")); listener.onFailure(invalidGrantException("No tokens provided for invalidation"));
} else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(),
attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate tokens"));
} else { } else {
BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
for (String tokenId : tokenIds) { for (String tokenId : tokenIds) {
@ -627,20 +679,30 @@ public final class TokenService {
} }
} }
if (retryTokenDocIds.isEmpty() == false) { if (retryTokenDocIds.isEmpty() == false) {
if (backoff.hasNext()) {
logger.debug("failed to invalidate [{}] tokens out of [{}], retrying to invalidate these too",
retryTokenDocIds.size(), tokenIds.size());
TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated,
failedRequestResponses, attemptCount.get()); failedRequestResponses);
attemptCount.incrementAndGet(); client.threadPool().schedule(
indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult); () -> indexInvalidation(retryTokenDocIds, listener, backoff, srcPrefix, incompleteResult),
backoff.next(), GENERIC);
} else {
logger.warn("failed to invalidate [{}] tokens out of [{}] after all retries",
retryTokenDocIds.size(), tokenIds.size());
} }
} else {
TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated,
failedRequestResponses, attemptCount.get()); failedRequestResponses);
listener.onResponse(result); listener.onResponse(result);
}
}, e -> { }, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e); Throwable cause = ExceptionsHelper.unwrapCause(e);
traceLog("invalidate tokens", cause); traceLog("invalidate tokens", cause);
if (isShardNotAvailableException(cause)) { if (isShardNotAvailableException(cause) && backoff.hasNext()) {
attemptCount.incrementAndGet(); logger.debug("failed to invalidate tokens, retrying ");
indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, previousResult); client.threadPool().schedule(
() -> indexInvalidation(tokenIds, listener, backoff, srcPrefix, previousResult), backoff.next(), GENERIC);
} else { } else {
listener.onFailure(e); listener.onFailure(e);
} }
@ -649,31 +711,34 @@ public final class TokenService {
} }
/** /**
* Uses the refresh token to refresh its associated token and returns the new token with an * Called by the transport action in order to start the process of refreshing a token.
* updated expiration date to the listener
*/ */
public void refreshToken(String refreshToken, ActionListener<Tuple<UserToken, String>> listener) { public void refreshToken(String refreshToken, ActionListener<Tuple<UserToken, String>> listener) {
ensureEnabled(); ensureEnabled();
final Instant refreshRequested = clock.instant();
final Iterator<TimeValue> backoff = DEFAULT_BACKOFF.iterator();
findTokenFromRefreshToken(refreshToken, findTokenFromRefreshToken(refreshToken,
ActionListener.wrap(tuple -> { ActionListener.wrap(searchResponse -> {
final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); final Authentication clientAuth = Authentication.readFromContext(client.threadPool().getThreadContext());
final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); final SearchHit tokenDocHit = searchResponse.getHits().getHits()[0];
innerRefresh(tokenDocId, clientAuth, listener, tuple.v2()); final String tokenDocId = tokenDocHit.getId();
innerRefresh(tokenDocId, tokenDocHit.getSourceAsMap(), tokenDocHit.getSeqNo(), tokenDocHit.getPrimaryTerm(), clientAuth,
listener, backoff, refreshRequested);
}, listener::onFailure), }, listener::onFailure),
new AtomicInteger(0)); backoff);
} }
private void findTokenFromRefreshToken(String refreshToken, ActionListener<Tuple<SearchResponse, AtomicInteger>> listener, /**
AtomicInteger attemptCount) { * Performs an asynchronous search request for the token document that contains the {@code refreshToken} and calls the listener with the
if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { * {@link SearchResponse}. In case of recoverable errors the SearchRequest is retried using an exponential backoff policy.
logger.warn("Failed to find token for refresh token [{}] after [{}] attempts", refreshToken, attemptCount.get()); */
listener.onFailure(invalidGrantException("could not refresh the requested token")); private void findTokenFromRefreshToken(String refreshToken, ActionListener<SearchResponse> listener,
} else { Iterator<TimeValue> backoff) {
SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setQuery(QueryBuilders.boolQuery() .setQuery(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE)) .filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
.filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
.setVersion(true) .seqNoAndPrimaryTerm(true)
.request(); .request();
final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
@ -682,29 +747,39 @@ public final class TokenService {
listener.onFailure(invalidGrantException("could not refresh the requested token")); listener.onFailure(invalidGrantException("could not refresh the requested token"));
} else if (frozenSecurityIndex.isAvailable() == false) { } else if (frozenSecurityIndex.isAvailable() == false) {
logger.debug("security index is not available to find token from refresh token, retrying"); logger.debug("security index is not available to find token from refresh token, retrying");
attemptCount.incrementAndGet(); client.threadPool().scheduleWithFixedDelay(
findTokenFromRefreshToken(refreshToken, listener, attemptCount); () -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
} else { } else {
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex)); Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("find by refresh token", refreshToken, ex));
securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> securityIndex.checkIndexVersionThenExecute(listener::onFailure, () ->
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
ActionListener.<SearchResponse>wrap(searchResponse -> { ActionListener.<SearchResponse>wrap(searchResponse -> {
if (searchResponse.isTimedOut()) { if (searchResponse.isTimedOut()) {
attemptCount.incrementAndGet(); if (backoff.hasNext()) {
findTokenFromRefreshToken(refreshToken, listener, attemptCount); client.threadPool().scheduleWithFixedDelay(
() -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
} else {
logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else if (searchResponse.getHits().getHits().length < 1) { } else if (searchResponse.getHits().getHits().length < 1) {
logger.info("could not find token document with refresh_token [{}]", refreshToken); logger.warn("could not find token document with refresh_token [{}]", refreshToken);
onFailure.accept(invalidGrantException("could not refresh the requested token")); onFailure.accept(invalidGrantException("could not refresh the requested token"));
} else if (searchResponse.getHits().getHits().length > 1) { } else if (searchResponse.getHits().getHits().length > 1) {
onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token")); onFailure.accept(new IllegalStateException("multiple tokens share the same refresh token"));
} else { } else {
listener.onResponse(new Tuple<>(searchResponse, attemptCount)); listener.onResponse(searchResponse);
} }
}, e -> { }, e -> {
if (isShardNotAvailableException(e)) { if (isShardNotAvailableException(e)) {
logger.debug("failed to search for token document, retrying", e); if (backoff.hasNext()) {
attemptCount.incrementAndGet(); logger.debug("failed to find token for refresh token [{}], retrying", refreshToken);
findTokenFromRefreshToken(refreshToken, listener, attemptCount); client.threadPool().scheduleWithFixedDelay(
() -> findTokenFromRefreshToken(refreshToken, listener, backoff), backoff.next(), GENERIC);
} else {
logger.warn("could not find token document with refresh_token [{}] after all retries", refreshToken);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else { } else {
onFailure.accept(e); onFailure.accept(e);
} }
@ -712,55 +787,178 @@ public final class TokenService {
client::search)); client::search));
} }
} }
}
/** /**
* Performs the actual refresh of the token with retries in case of certain exceptions that * Performs the actual refresh of the token with retries in case of certain exceptions that may be recoverable. The
* may be recoverable. The refresh involves retrieval of the token document and then * refresh involves two steps:
* updating the token document to indicate that the document has been refreshed. * First, we check if the token document is still valid for refresh ({@link TokenService#checkTokenDocForRefresh(Map, Authentication)}
* Then, in the case that the token has been refreshed within the previous 30 seconds (see
* {@link TokenService#checkLenientlyIfTokenAlreadyRefreshed(Map, Authentication)}), we do not create a new token document
* but instead retrieve the one that was created by the original refresh and return a user token and
* refresh token based on that ( see {@link TokenService#reIssueTokens(Map, String, ActionListener)} ).
* Otherwise this token document gets its refresh_token marked as refreshed, while also storing the Instant when it was
* refreshed along with a pointer to the new token document that holds the refresh_token that supersedes this one. The new
* document that contains the new access token and refresh token is created and finally the new access token and refresh token are
* returned to the listener.
*/ */
private void innerRefresh(String tokenDocId, Authentication clientAuth, ActionListener<Tuple<UserToken, String>> listener, private void innerRefresh(String tokenDocId, Map<String, Object> source, long seqNo, long primaryTerm, Authentication clientAuth,
AtomicInteger attemptCount) { ActionListener<Tuple<UserToken, String>> listener, Iterator<TimeValue> backoff, Instant refreshRequested) {
if (attemptCount.getAndIncrement() > MAX_RETRY_ATTEMPTS) { logger.debug("Attempting to refresh token [{}]", tokenDocId);
logger.warn("Failed to refresh token for doc [{}] after [{}] attempts", tokenDocId, attemptCount.get());
listener.onFailure(invalidGrantException("could not refresh the requested token"));
} else {
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex)); Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("refresh token", tokenDocId, ex));
GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request();
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest,
ActionListener.<GetResponse>wrap(response -> {
if (response.isExists()) {
final Map<String, Object> source = response.getSource();
final Optional<ElasticsearchSecurityException> invalidSource = checkTokenDocForRefresh(source, clientAuth); final Optional<ElasticsearchSecurityException> invalidSource = checkTokenDocForRefresh(source, clientAuth);
if (invalidSource.isPresent()) { if (invalidSource.isPresent()) {
onFailure.accept(invalidSource.get()); onFailure.accept(invalidSource.get());
} else {
if (eligibleForMultiRefresh(source, refreshRequested)) {
final Map<String, Object> refreshTokenSrc = (Map<String, Object>) source.get("refresh_token");
final String supersedingTokenDocId = (String) refreshTokenSrc.get("superseded_by");
logger.debug("Token document [{}] was recently refreshed, attempting to reuse [{}] for returning an " +
"access token and refresh token", tokenDocId, supersedingTokenDocId);
final ActionListener<GetResponse> getSupersedingListener = new ActionListener<GetResponse>() {
@Override
public void onResponse(GetResponse response) {
if (response.isExists()) {
logger.debug("Found superseding token document [{}] ", supersedingTokenDocId);
final Map<String, Object> supersedingTokenSource = response.getSource();
final Map<String, Object> supersedingUserTokenSource = (Map<String, Object>)
((Map<String, Object>) supersedingTokenSource.get("access_token")).get("user_token");
final Map<String, Object> supersedingRefreshTokenSrc =
(Map<String, Object>) supersedingTokenSource.get("refresh_token");
final String supersedingRefreshTokenValue = (String) supersedingRefreshTokenSrc.get("token");
reIssueTokens(supersedingUserTokenSource, supersedingRefreshTokenValue, listener);
} else if (backoff.hasNext()) {
// We retry this since the creation of the superseding token document might already be in flight but not
// yet completed, triggered by a refresh request that came a few milliseconds ago
logger.info("could not find superseding token document [{}] for token document [{}], retrying",
supersedingTokenDocId, tokenDocId);
client.threadPool().schedule(() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC);
} else {
logger.warn("could not find superseding token document [{}] for token document [{}] after all retries",
supersedingTokenDocId, tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
}
@Override
public void onFailure(Exception e) {
if (isShardNotAvailableException(e)) {
if (backoff.hasNext()) {
logger.info("could not find superseding token document [{}] for refresh, retrying", supersedingTokenDocId);
client.threadPool().schedule(
() -> getTokenDocAsync(supersedingTokenDocId, this), backoff.next(), GENERIC);
} else {
logger.warn("could not find token document [{}] for refresh after all retries", supersedingTokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else {
logger.warn("could not find superseding token document [{}] for refresh", supersedingTokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
}
};
getTokenDocAsync(supersedingTokenDocId, getSupersedingListener);
} else { } else {
final Map<String, Object> userTokenSource = (Map<String, Object>) final Map<String, Object> userTokenSource = (Map<String, Object>)
((Map<String, Object>) source.get("access_token")).get("user_token"); ((Map<String, Object>) source.get("access_token")).get("user_token");
final String authString = (String) userTokenSource.get("authentication"); final String authString = (String) userTokenSource.get("authentication");
final Integer version = (Integer) userTokenSource.get("version"); final Integer version = (Integer) userTokenSource.get("version");
final Map<String, Object> metadata = (Map<String, Object>) userTokenSource.get("metadata"); final Map<String, Object> metadata = (Map<String, Object>) userTokenSource.get("metadata");
Version authVersion = Version.fromId(version); Version authVersion = Version.fromId(version);
Authentication authentication;
try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) {
in.setVersion(authVersion); in.setVersion(authVersion);
Authentication authentication = new Authentication(in); authentication = new Authentication(in);
} catch (IOException e) {
logger.error("failed to decode the authentication stored with token document [{}]", tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
return;
}
final String newUserTokenId = UUIDs.randomBase64UUID();
final Instant refreshTime = clock.instant();
Map<String, Object> updateMap = new HashMap<>();
updateMap.put("refreshed", true);
updateMap.put("refresh_time", refreshTime.toEpochMilli());
updateMap.put("superseded_by", getTokenDocumentId(newUserTokenId));
UpdateRequestBuilder updateRequest = UpdateRequestBuilder updateRequest =
client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId)
.setDoc("refresh_token", Collections.singletonMap("refreshed", true)) .setDoc("refresh_token", updateMap)
.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); .setFetchSource(true)
updateRequest.setIfSeqNo(response.getSeqNo()); .setRefreshPolicy(RefreshPolicy.IMMEDIATE);
updateRequest.setIfPrimaryTerm(response.getPrimaryTerm()); assert seqNo != SequenceNumbers.UNASSIGNED_SEQ_NO : "expected an assigned sequence number";
updateRequest.setIfSeqNo(seqNo);
assert primaryTerm != SequenceNumbers.UNASSIGNED_PRIMARY_TERM : "expected an assigned primary term";
updateRequest.setIfPrimaryTerm(primaryTerm);
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(), executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest.request(),
ActionListener.<UpdateResponse>wrap( ActionListener.<UpdateResponse>wrap(
updateResponse -> createUserToken(authentication, clientAuth, listener, metadata, true), updateResponse -> {
e -> { if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) {
logger.debug("updated the original token document to {}", updateResponse.getGetResult().sourceAsMap());
createUserToken(newUserTokenId, authentication, clientAuth, listener, metadata, true);
} else if (backoff.hasNext()) {
logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying",
tokenDocId, updateResponse.getResult());
client.threadPool().schedule(
() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
refreshRequested),
backoff.next(), GENERIC);
} else {
logger.info("failed to update the original token document [{}] after all retries, " +
"the update result was [{}]. ", tokenDocId, updateResponse.getResult());
listener.onFailure(invalidGrantException("could not refresh the requested token"));
}
}, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e); Throwable cause = ExceptionsHelper.unwrapCause(e);
if (cause instanceof VersionConflictEngineException || if (cause instanceof VersionConflictEngineException) {
isShardNotAvailableException(e)) { //The document has been updated by another thread, get it again.
innerRefresh(tokenDocId, clientAuth, if (backoff.hasNext()) {
listener, attemptCount); logger.debug("version conflict while updating document [{}], attempting to get it again",
tokenDocId);
final ActionListener<GetResponse> getListener = new ActionListener<GetResponse>() {
@Override
public void onResponse(GetResponse response) {
if (response.isExists()) {
innerRefresh(tokenDocId, response.getSource(), response.getSeqNo(),
response.getPrimaryTerm(), clientAuth, listener, backoff, refreshRequested);
} else {
logger.warn("could not find token document [{}] for refresh", tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
}
@Override
public void onFailure(Exception e) {
if (isShardNotAvailableException(e)) {
if (backoff.hasNext()) {
logger.info("could not get token document [{}] for refresh, " +
"retrying", tokenDocId);
client.threadPool().schedule(
() -> getTokenDocAsync(tokenDocId, this), backoff.next(), GENERIC);
} else {
logger.warn("could not get token document [{}] for refresh after all retries",
tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else {
onFailure.accept(e);
}
}
};
getTokenDocAsync(tokenDocId, getListener);
} else {
logger.warn("version conflict while updating document [{}], no retries left", tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else if (isShardNotAvailableException(e)) {
if (backoff.hasNext()) {
logger.debug("failed to update the original token document [{}], retrying", tokenDocId);
client.threadPool().schedule(
() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
refreshRequested),
backoff.next(), GENERIC);
} else {
logger.warn("failed to update the original token document [{}], after all retries", tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
} else { } else {
onFailure.accept(e); onFailure.accept(e);
} }
@ -768,23 +966,17 @@ public final class TokenService {
client::update); client::update);
} }
} }
} else {
logger.info("could not find token document [{}] for refresh", tokenDocId);
onFailure.accept(invalidGrantException("could not refresh the requested token"));
}
}, e -> {
if (isShardNotAvailableException(e)) {
innerRefresh(tokenDocId, clientAuth, listener, attemptCount);
} else {
listener.onFailure(e);
}
}), client::get);
} }
private void getTokenDocAsync(String tokenDocId, ActionListener<GetResponse> listener) {
GetRequest getRequest =
client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request();
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, listener, client::get);
} }
/** /**
* Performs checks on the retrieved source and returns an {@link Optional} with the exception * Performs checks on the retrieved source and returns an {@link Optional} with the exception
* if there is an issue * if there is an issue that makes the retrieved token unsuitable to be refreshed
*/ */
private Optional<ElasticsearchSecurityException> checkTokenDocForRefresh(Map<String, Object> source, Authentication clientAuth) { private Optional<ElasticsearchSecurityException> checkTokenDocForRefresh(Map<String, Object> source, Authentication clientAuth) {
final Map<String, Object> refreshTokenSrc = (Map<String, Object>) source.get("refresh_token"); final Map<String, Object> refreshTokenSrc = (Map<String, Object>) source.get("refresh_token");
@ -805,8 +997,6 @@ public final class TokenService {
return Optional.of(invalidGrantException("token document is missing invalidated value")); return Optional.of(invalidGrantException("token document is missing invalidated value"));
} else if (creationEpochMilli == null) { } else if (creationEpochMilli == null) {
return Optional.of(invalidGrantException("token document is missing creation time value")); return Optional.of(invalidGrantException("token document is missing creation time value"));
} else if (refreshed) {
return Optional.of(invalidGrantException("token has already been refreshed"));
} else if (invalidated) { } else if (invalidated) {
return Optional.of(invalidGrantException("token has been invalidated")); return Optional.of(invalidGrantException("token has been invalidated"));
} else if (clock.instant().isAfter(creationTime.plus(24L, ChronoUnit.HOURS))) { } else if (clock.instant().isAfter(creationTime.plus(24L, ChronoUnit.HOURS))) {
@ -820,7 +1010,7 @@ public final class TokenService {
} else if (userTokenSrc.get("metadata") == null) { } else if (userTokenSrc.get("metadata") == null) {
return Optional.of(invalidGrantException("token is missing metadata")); return Optional.of(invalidGrantException("token is missing metadata"));
} else { } else {
return checkClient(refreshTokenSrc, clientAuth); return checkLenientlyIfTokenAlreadyRefreshed(source, clientAuth);
} }
} }
} }
@ -830,14 +1020,79 @@ public final class TokenService {
if (clientInfo == null) { if (clientInfo == null) {
return Optional.of(invalidGrantException("token is missing client information")); return Optional.of(invalidGrantException("token is missing client information"));
} else if (clientAuth.getUser().principal().equals(clientInfo.get("user")) == false) { } else if (clientAuth.getUser().principal().equals(clientInfo.get("user")) == false) {
logger.warn("Token was originally created by [{}] but [{}] attempted to refresh it", clientInfo.get("user"),
clientAuth.getUser().principal());
return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client"));
} else if (clientAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) { } else if (clientAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) {
logger.warn("[{}] created the refresh token while authenticated by [{}] but is now authenticated by [{}]",
clientInfo.get("user"), clientInfo.get("realm"), clientAuth.getAuthenticatedBy().getName());
return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); return Optional.of(invalidGrantException("tokens must be refreshed by the creating client"));
} else { } else {
return Optional.empty(); return Optional.empty();
} }
} }
/**
* Checks if the retrieved refresh token is already refreshed taking into consideration that we allow refresh tokens
* to be refreshed multiple times for a very small time window in order to gracefully handle multiple concurrent requests
* from clients
*/
@SuppressWarnings("unchecked")
private Optional<ElasticsearchSecurityException> checkLenientlyIfTokenAlreadyRefreshed(Map<String, Object> source,
Authentication userAuth) {
final Map<String, Object> refreshTokenSrc = (Map<String, Object>) source.get("refresh_token");
final Map<String, Object> userTokenSource = (Map<String, Object>)
((Map<String, Object>) source.get("access_token")).get("user_token");
final Integer version = (Integer) userTokenSource.get("version");
Version authVersion = Version.fromId(version);
final Boolean refreshed = (Boolean) refreshTokenSrc.get("refreshed");
if (refreshed) {
if (authVersion.onOrAfter(Version.V_7_1_0)) {
final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time");
final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli);
final String supersededBy = (String) refreshTokenSrc.get("superseded_by");
if (supersededBy == null) {
return Optional.of(invalidGrantException("token document is missing superseded by value"));
} else if (refreshTime == null) {
return Optional.of(invalidGrantException("token document is missing refresh time value"));
} else if (clock.instant().isAfter(refreshTime.plus(30L, ChronoUnit.SECONDS))) {
return Optional.of(invalidGrantException("token has already been refreshed more than 30 seconds in the past"));
}
} else {
return Optional.of(invalidGrantException("token has already been refreshed"));
}
}
return checkClient(refreshTokenSrc, userAuth);
}
/**
* Checks if a refreshed token is eligible to be refreshed again. This is only allowed for versions after 7.1.0 and
* when the refresh_token contains the refresh_time and superseded_by fields and it has been refreshed in a specific
* time period of 60 seconds. The period is defined as 30 seconds before the token was refreshed until 30 seconds after. The
* time window needs to handle instants before the request time as we capture an instant early on in
* {@link TokenService#refreshToken(String, ActionListener)} and in the case of multiple concurrent requests,
* the {@code refreshRequested} when dealing with one of the subsequent requests might well be <em>before</em> the instant when
* the first of the requests refreshed the token.
*
* @param source The source of the token document that contains the originally refreshed token
* @param refreshRequested The instant when the this refresh request was acknowledged by the TokenService
*/
private boolean eligibleForMultiRefresh(Map<String, Object> source, Instant refreshRequested) {
final Map<String, Object> refreshTokenSrc = (Map<String, Object>) source.get("refresh_token");
final Map<String, Object> userTokenSource = (Map<String, Object>)
((Map<String, Object>) source.get("access_token")).get("user_token");
final Integer version = (Integer) userTokenSource.get("version");
Version authVersion = Version.fromId(version);
final Long refreshedEpochMilli = (Long) refreshTokenSrc.get("refresh_time");
final Instant refreshTime = refreshedEpochMilli == null ? null : Instant.ofEpochMilli(refreshedEpochMilli);
final String supersededBy = (String) refreshTokenSrc.get("superseded_by");
return authVersion.onOrAfter(Version.V_7_1_0)
&& supersededBy != null
&& refreshTime != null
&& refreshRequested.isBefore(refreshTime.plus(30L, ChronoUnit.SECONDS))
&& refreshRequested.isAfter(refreshTime.minus(30L, ChronoUnit.SECONDS));
}
/** /**
* Find stored refresh and access tokens that have not been invalidated or expired, and were issued against * Find stored refresh and access tokens that have not been invalidated or expired, and were issued against
* the specified realm. * the specified realm.
@ -893,7 +1148,6 @@ public final class TokenService {
*/ */
public void findActiveTokensForUser(String username, ActionListener<Collection<Tuple<UserToken, String>>> listener) { public void findActiveTokensForUser(String username, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
ensureEnabled(); ensureEnabled();
final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
if (Strings.isNullOrEmpty(username)) { if (Strings.isNullOrEmpty(username)) {
listener.onFailure(new IllegalArgumentException("username is required")); listener.onFailure(new IllegalArgumentException("username is required"));
@ -958,7 +1212,6 @@ public final class TokenService {
} }
/** /**
*
* Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token * Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token
* *
* @param source The token document source as retrieved * @param source The token document source as retrieved
@ -968,7 +1221,6 @@ public final class TokenService {
*/ */
private Tuple<UserToken, String> parseTokensFromDocument(Map<String, Object> source, @Nullable Predicate<Map<String, Object>> filter) private Tuple<UserToken, String> parseTokensFromDocument(Map<String, Object> source, @Nullable Predicate<Map<String, Object>> filter)
throws IOException { throws IOException {
final String refreshToken = (String) ((Map<String, Object>) source.get("refresh_token")).get("token"); final String refreshToken = (String) ((Map<String, Object>) source.get("refresh_token")).get("token");
final Map<String, Object> userTokenSource = (Map<String, Object>) final Map<String, Object> userTokenSource = (Map<String, Object>)
((Map<String, Object>) source.get("access_token")).get("user_token"); ((Map<String, Object>) source.get("access_token")).get("user_token");
@ -1063,7 +1315,6 @@ public final class TokenService {
} }
} }
public TimeValue getExpirationDelay() { public TimeValue getExpirationDelay() {
return expirationDelay; return expirationDelay;
} }
@ -1095,9 +1346,21 @@ public final class TokenService {
} }
/** /**
* Serializes a token to a String containing an encrypted representation of the token * Serializes a token to a String containing the version of the node that created the token and
* either an encrypted representation of the token id for versions earlier to 7.0.0 or the token ie
* itself for versions after 7.0.0
*/ */
public String getUserTokenString(UserToken userToken) throws IOException, GeneralSecurityException { public String getAccessTokenAsString(UserToken userToken) throws IOException, GeneralSecurityException {
if (clusterService.state().nodes().getMinNodeVersion().onOrAfter(Version.V_7_1_0)) {
try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os);
StreamOutput out = new OutputStreamStreamOutput(base64)) {
out.setVersion(userToken.getVersion());
Version.writeVersion(userToken.getVersion(), out);
out.writeString(userToken.getId());
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}
} else {
// we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly // we know that the minimum length is larger than the default of the ByteArrayOutputStream so set the size to this explicitly
try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os); OutputStream base64 = Base64.getEncoder().wrap(os);
@ -1119,13 +1382,15 @@ public final class TokenService {
} }
} }
} }
}
private void ensureEncryptionCiphersSupported() throws NoSuchPaddingException, NoSuchAlgorithmException { private void ensureEncryptionCiphersSupported() throws NoSuchPaddingException, NoSuchAlgorithmException {
Cipher.getInstance(ENCRYPTION_CIPHER); Cipher.getInstance(ENCRYPTION_CIPHER);
SecretKeyFactory.getInstance(KDF_ALGORITHM); SecretKeyFactory.getInstance(KDF_ALGORITHM);
} }
private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { // Package private for testing
Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER);
BytesKey salt = keyAndCache.getSalt(); BytesKey salt = keyAndCache.getSalt();
try { try {
@ -1147,7 +1412,8 @@ public final class TokenService {
return cipher; return cipher;
} }
private byte[] getNewInitializationVector() { // Package private for testing
byte[] getNewInitializationVector() {
final byte[] initializationVector = new byte[IV_BYTES]; final byte[] initializationVector = new byte[IV_BYTES];
secureRandom.nextBytes(initializationVector); secureRandom.nextBytes(initializationVector);
return initializationVector; return initializationVector;
@ -1528,12 +1794,19 @@ public final class TokenService {
} }
/** /**
* For testing * Package private for testing
*/ */
void clearActiveKeyCache() { void clearActiveKeyCache() {
this.keyCache.activeKeyCache.keyCache.invalidateAll(); this.keyCache.activeKeyCache.keyCache.invalidateAll();
} }
/**
* Package private for testing
*/
KeyAndCache getActiveKeyCache() {
return this.keyCache.activeKeyCache;
}
static final class KeyAndCache implements Closeable { static final class KeyAndCache implements Closeable {
private final KeyAndTimestamp keyAndTimestamp; private final KeyAndTimestamp keyAndTimestamp;
private final Cache<BytesKey, SecretKey> keyCache; private final Cache<BytesKey, SecretKey> keyCache;

View File

@ -242,7 +242,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true); tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true);
final UserToken userToken = future.actionGet().v1(); final UserToken userToken = future.actionGet().v1();
mockGetTokenFromId(userToken, false, client); mockGetTokenFromId(userToken, false, client);
final String tokenString = tokenService.getUserTokenString(userToken); final String tokenString = tokenService.getAccessTokenAsString(userToken);
final SamlLogoutRequest request = new SamlLogoutRequest(); final SamlLogoutRequest request = new SamlLogoutRequest();
request.setToken(tokenString); request.setToken(tokenString);

View File

@ -1109,7 +1109,7 @@ public class AuthenticationServiceTests extends ESTestCase {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
} }
String token = tokenService.getUserTokenString(tokenFuture.get().v1()); String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1());
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
mockGetTokenFromId(tokenFuture.get().v1(), false, client); mockGetTokenFromId(tokenFuture.get().v1(), false, client);
when(securityIndex.isAvailable()).thenReturn(true); when(securityIndex.isAvailable()).thenReturn(true);
@ -1192,7 +1192,7 @@ public class AuthenticationServiceTests extends ESTestCase {
Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null);
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
} }
String token = tokenService.getUserTokenString(tokenFuture.get().v1()); String token = tokenService.getAccessTokenAsString(tokenFuture.get().v1());
mockGetTokenFromId(tokenFuture.get().v1(), true, client); mockGetTokenFromId(tokenFuture.get().v1(), true, client);
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[1]).run(); ((Runnable) invocationOnMock.getArguments()[1]).run();

View File

@ -5,11 +5,14 @@
*/ */
package org.elasticsearch.xpack.security.authc; package org.elasticsearch.xpack.security.authc;
import org.apache.directory.api.util.Strings;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse; import org.elasticsearch.action.admin.cluster.state.ClusterStateResponse;
import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.support.WriteRequest; import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client; import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse;
import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SecureString;
@ -23,6 +26,7 @@ import org.elasticsearch.test.SecuritySettingsSource;
import org.elasticsearch.test.SecuritySettingsSourceField; import org.elasticsearch.test.SecuritySettingsSourceField;
import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.test.junit.annotations.TestLogging;
import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.CreateTokenResponse;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse;
@ -38,7 +42,13 @@ import org.junit.Before;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
@ -330,7 +340,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); assertEquals("token has been invalidated", e.getHeader("error_description").get(0));
} }
public void testRefreshingMultipleTimes() { public void testRefreshingMultipleTimesFails() throws Exception {
Client client = client().filterWithHeader(Collections.singletonMap("Authorization", Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
@ -343,12 +353,101 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
assertNotNull(createTokenResponse.getRefreshToken()); assertNotNull(createTokenResponse.getRefreshToken());
CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get();
assertNotNull(refreshResponse); assertNotNull(refreshResponse);
// We now have two documents, the original(now refreshed) token doc and the new one with the new access doc
AtomicReference<String> docId = new AtomicReference<>();
assertBusy(() -> {
SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setSource(SearchSourceBuilder.searchSource()
.query(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("doc_type", "token"))
.must(QueryBuilders.termQuery("refresh_token.refreshed", "true"))))
.setSize(1)
.setTerminateAfter(1)
.get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo(1L));
docId.set(searchResponse.getHits().getAt(0).getId());
});
// hack doc to modify the refresh time to 50 seconds ago so that we don't hit the lenient refresh case
Instant refreshed = Instant.now();
Instant aWhileAgo = refreshed.minus(50L, ChronoUnit.SECONDS);
assertTrue(Instant.now().isAfter(aWhileAgo));
UpdateResponse updateResponse = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get())
.setDoc("refresh_token", Collections.singletonMap("refresh_time", aWhileAgo.toEpochMilli()))
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.setFetchSource("refresh_token", Strings.EMPTY_STRING)
.get();
assertNotNull(updateResponse);
Map<String, Object> refreshTokenMap = (Map<String, Object>) updateResponse.getGetResult().sourceAsMap().get("refresh_token");
assertTrue(
Instant.ofEpochMilli((long) refreshTokenMap.get("refresh_time")).isBefore(Instant.now().minus(30L, ChronoUnit.SECONDS)));
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get());
assertEquals("invalid_grant", e.getMessage()); assertEquals("invalid_grant", e.getMessage());
assertEquals(RestStatus.BAD_REQUEST, e.status()); assertEquals(RestStatus.BAD_REQUEST, e.status());
assertEquals("token has already been refreshed", e.getHeader("error_description").get(0)); assertEquals("token has already been refreshed more than 30 seconds in the past", e.getHeader("error_description").get(0));
}
public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception {
Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient securityClient = new SecurityClient(client);
Set<String> refreshTokens = new HashSet<>();
Set<String> accessTokens = new HashSet<>();
CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken()
.setGrantType("password")
.setUsername(SecuritySettingsSource.TEST_USER_NAME)
.setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()))
.get();
assertNotNull(createTokenResponse.getRefreshToken());
final int numberOfProcessors = Runtime.getRuntime().availableProcessors();
final int numberOfThreads = scaledRandomIntBetween((numberOfProcessors + 1) / 2, numberOfProcessors * 3);
List<Thread> threads = new ArrayList<>(numberOfThreads);
final CountDownLatch readyLatch = new CountDownLatch(numberOfThreads + 1);
final CountDownLatch completedLatch = new CountDownLatch(numberOfThreads);
AtomicBoolean failed = new AtomicBoolean();
for (int i = 0; i < numberOfThreads; i++) {
threads.add(new Thread(() -> {
// Each thread gets its own client so that more than one nodes will be hit
Client threadClient = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient threadSecurityClient = new SecurityClient(threadClient);
CreateTokenRequest refreshRequest =
threadSecurityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).request();
readyLatch.countDown();
try {
readyLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
completedLatch.countDown();
return;
}
threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> {
accessTokens.add(result.getTokenString());
refreshTokens.add(result.getRefreshToken());
logger.info("received access token [{}] and refresh token [{}]", result.getTokenString(), result.getRefreshToken());
completedLatch.countDown();
}, e -> {
failed.set(true);
completedLatch.countDown();
logger.error("caught exception", e);
}));
}));
}
for (Thread thread : threads) {
thread.start();
}
readyLatch.countDown();
readyLatch.await();
for (Thread thread : threads) {
thread.join();
}
completedLatch.await();
assertThat(failed.get(), equalTo(false));
assertThat(accessTokens.size(), equalTo(1));
assertThat(refreshTokens.size(), equalTo(1));
} }
public void testRefreshAsDifferentUser() { public void testRefreshAsDifferentUser() {

View File

@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.authc; package org.elasticsearch.xpack.security.authc;
import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.get.GetAction; import org.elasticsearch.action.get.GetAction;
@ -23,6 +24,8 @@ import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings; import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.concurrent.ThreadContext;
@ -51,7 +54,11 @@ import org.junit.AfterClass;
import org.junit.Before; import org.junit.Before;
import org.junit.BeforeClass; import org.junit.BeforeClass;
import java.io.ByteArrayOutputStream;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.time.Clock; import java.time.Clock;
import java.time.Instant; import java.time.Instant;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
@ -61,6 +68,7 @@ import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKey; import javax.crypto.SecretKey;
import static java.time.Clock.systemUTC; import static java.time.Clock.systemUTC;
@ -151,7 +159,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getAccessTokenAsString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -198,7 +206,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -219,10 +227,10 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1(); final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken); assertNotNull(newToken);
assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));
requestContext = new ThreadContext(Settings.EMPTY); requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken));
mockGetTokenFromId(newToken, false); mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
@ -258,7 +266,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
otherTokenService.getAndValidateToken(requestContext, future); otherTokenService.getAndValidateToken(requestContext, future);
@ -289,7 +297,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -316,7 +324,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1(); final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken); assertNotNull(newToken);
assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));
metaData = tokenService.pruneKeys(1); metaData = tokenService.pruneKeys(1);
tokenService.refreshMetaData(metaData); tokenService.refreshMetaData(metaData);
@ -329,7 +337,7 @@ public class TokenServiceTests extends ESTestCase {
} }
requestContext = new ThreadContext(Settings.EMPTY); requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken));
mockGetTokenFromId(newToken, false); mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -351,7 +359,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -362,7 +370,7 @@ public class TokenServiceTests extends ESTestCase {
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
// verify a second separate token service with its own passphrase cannot verify // verify a second separate token service with its own passphrase cannot verify
TokenService anotherService = new TokenService(Settings.EMPTY, systemUTC(), client, securityIndex, TokenService anotherService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, securityIndex,
clusterService); clusterService);
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
anotherService.getAndValidateToken(requestContext, future); anotherService.getAndValidateToken(requestContext, future);
@ -377,10 +385,10 @@ public class TokenServiceTests extends ESTestCase {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>(); PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true); tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
UserToken token = tokenFuture.get().v1(); UserToken token = tokenFuture.get().v1();
assertThat(tokenService.getUserTokenString(token), notNullValue()); assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue());
tokenService.clearActiveKeyCache(); tokenService.clearActiveKeyCache();
assertThat(tokenService.getUserTokenString(token), notNullValue()); assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue());
} }
public void testInvalidatedToken() throws Exception { public void testInvalidatedToken() throws Exception {
@ -395,7 +403,7 @@ public class TokenServiceTests extends ESTestCase {
mockGetTokenFromId(token, true); mockGetTokenFromId(token, true);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>(); PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -449,7 +457,7 @@ public class TokenServiceTests extends ESTestCase {
authentication = token.getAuthentication(); authentication = token.getAuthentication();
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
// the clock is still frozen, so the cookie should be valid // the clock is still frozen, so the cookie should be valid
@ -559,7 +567,7 @@ public class TokenServiceTests extends ESTestCase {
//mockGetTokenFromId(token, false); //mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY); ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); requestContext.putHeader("Authorization", "Bearer " + tokenService.getAccessTokenAsString(token));
doAnswer(invocationOnMock -> { doAnswer(invocationOnMock -> {
ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1]; ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
@ -598,7 +606,7 @@ public class TokenServiceTests extends ESTestCase {
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS)); UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS));
mockGetTokenFromId(expired, false); mockGetTokenFromId(expired, false);
String userTokenString = tokenService.getUserTokenString(expired); String userTokenString = tokenService.getAccessTokenAsString(expired);
PlainActionFuture<Tuple<Authentication, Map<String, Object>>> authFuture = new PlainActionFuture<>(); PlainActionFuture<Tuple<Authentication, Map<String, Object>>> authFuture = new PlainActionFuture<>();
tokenService.getAuthenticationAndMetaData(userTokenString, authFuture); tokenService.getAuthenticationAndMetaData(userTokenString, authFuture);
Authentication retrievedAuth = authFuture.actionGet().v1(); Authentication retrievedAuth = authFuture.actionGet().v1();
@ -639,4 +647,28 @@ public class TokenServiceTests extends ESTestCase {
assertEquals(expected.getMetadata(), result.getMetadata()); assertEquals(expected.getMetadata(), result.getMetadata());
assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType()); assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType());
} }
protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException,
GeneralSecurityException {
try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os);
StreamOutput out = new OutputStreamStreamOutput(base64)) {
out.setVersion(Version.V_7_0_0);
TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache();
Version.writeVersion(Version.V_7_0_0, out);
out.writeByteArray(keyAndCache.getSalt().bytes);
out.writeByteArray(keyAndCache.getKeyHash().bytes);
final byte[] initializationVector = tokenService.getNewInitializationVector();
out.writeByteArray(initializationVector);
try (CipherOutputStream encryptedOutput =
new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0));
StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) {
encryptedStreamOutput.setVersion(Version.V_7_0_0);
encryptedStreamOutput.writeString(userToken.getId());
encryptedStreamOutput.close();
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}
}
}
} }

View File

@ -25,8 +25,7 @@ public class TokensInvalidationResultTests extends ESTestCase {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"),
Arrays.asList("token3", "token4"), Arrays.asList("token3", "token4"),
Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")),
new ElasticsearchException("boo", new IllegalStateException("far"))), new ElasticsearchException("boo", new IllegalStateException("far"))));
randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) { try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS); result.toXContent(builder, ToXContent.EMPTY_PARAMS);
@ -56,9 +55,8 @@ public class TokensInvalidationResultTests extends ESTestCase {
} }
public void testToXcontentWithNoErrors() throws Exception{ public void testToXcontentWithNoErrors() throws Exception{
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), Collections.emptyList(),
Collections.emptyList(), Collections.emptyList());
Collections.emptyList(), randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) { try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS); result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder), assertThat(Strings.toString(builder),