Remove bwc logic for token invalidation (#36893)

- Removes bwc invalidation logic from the TokenService
- Removes bwc serialization for InvalidateTokenResponse objects as
    old nodes in supported mixed clusters during upgrade will be 6.7 and
    thus will know of the new format
- Removes the created field from the TokensInvalidationResult and the
    InvalidateTokenResponse as it is no longer useful in > 7.0
This commit is contained in:
Ioannis Kakavas 2018-12-28 13:09:42 +02:00 committed by GitHub
parent 44bd7db59e
commit 0cae979dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 172 additions and 580 deletions

View File

@ -42,13 +42,11 @@ import static org.elasticsearch.common.xcontent.ConstructingObjectParser.optiona
*/
public final class InvalidateTokenResponse {
public static final ParseField CREATED = new ParseField("created");
public static final ParseField INVALIDATED_TOKENS = new ParseField("invalidated_tokens");
public static final ParseField PREVIOUSLY_INVALIDATED_TOKENS = new ParseField("previously_invalidated_tokens");
public static final ParseField ERROR_COUNT = new ParseField("error_count");
public static final ParseField ERRORS = new ParseField("error_details");
private final boolean created;
private final int invalidatedTokens;
private final int previouslyInvalidatedTokens;
private List<ElasticsearchException> errors;
@ -57,19 +55,17 @@ public final class InvalidateTokenResponse {
private static final ConstructingObjectParser<InvalidateTokenResponse, Void> PARSER = new ConstructingObjectParser<>(
"tokens_invalidation_result", true,
// we parse but do not use the count of errors as we implicitly have this in the size of the Exceptions list
args -> new InvalidateTokenResponse((boolean) args[0], (int) args[1], (int) args[2], (List<ElasticsearchException>) args[4]));
args -> new InvalidateTokenResponse((int) args[0], (int) args[1], (List<ElasticsearchException>) args[3]));
static {
PARSER.declareBoolean(constructorArg(), CREATED);
PARSER.declareInt(constructorArg(), INVALIDATED_TOKENS);
PARSER.declareInt(constructorArg(), PREVIOUSLY_INVALIDATED_TOKENS);
PARSER.declareInt(constructorArg(), ERROR_COUNT);
PARSER.declareObjectArray(optionalConstructorArg(), (p, c) -> ElasticsearchException.fromXContent(p), ERRORS);
}
public InvalidateTokenResponse(boolean created, int invalidatedTokens, int previouslyInvalidatedTokens,
public InvalidateTokenResponse(int invalidatedTokens, int previouslyInvalidatedTokens,
@Nullable List<ElasticsearchException> errors) {
this.created = created;
this.invalidatedTokens = invalidatedTokens;
this.previouslyInvalidatedTokens = previouslyInvalidatedTokens;
if (null == errors) {
@ -79,10 +75,6 @@ public final class InvalidateTokenResponse {
}
}
public boolean isCreated() {
return created;
}
public int getInvalidatedTokens() {
return invalidatedTokens;
}
@ -104,15 +96,14 @@ public final class InvalidateTokenResponse {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InvalidateTokenResponse that = (InvalidateTokenResponse) o;
return created == that.created &&
invalidatedTokens == that.invalidatedTokens &&
return invalidatedTokens == that.invalidatedTokens &&
previouslyInvalidatedTokens == that.previouslyInvalidatedTokens &&
Objects.equals(errors, that.errors);
}
@Override
public int hashCode() {
return Objects.hash(created, invalidatedTokens, previouslyInvalidatedTokens, errors);
return Objects.hash(invalidatedTokens, previouslyInvalidatedTokens, errors);
}
public static InvalidateTokenResponse fromXContent(XContentParser parser) throws IOException {

View File

@ -41,7 +41,6 @@ public class InvalidateTokenResponseTests extends ESTestCase {
final int invalidatedTokens = randomInt(32);
final int previouslyInvalidatedTokens = randomInt(32);
builder.startObject()
.field("created", false)
.field("invalidated_tokens", invalidatedTokens)
.field("previously_invalidated_tokens", previouslyInvalidatedTokens)
.field("error_count", 0)
@ -50,7 +49,6 @@ public class InvalidateTokenResponseTests extends ESTestCase {
try (XContentParser parser = createParser(xContentType.xContent(), xContent)) {
final InvalidateTokenResponse response = InvalidateTokenResponse.fromXContent(parser);
assertThat(response.isCreated(), Matchers.equalTo(false));
assertThat(response.getInvalidatedTokens(), Matchers.equalTo(invalidatedTokens));
assertThat(response.getPreviouslyInvalidatedTokens(), Matchers.equalTo(previouslyInvalidatedTokens));
assertThat(response.getErrorsCount(), Matchers.equalTo(0));
@ -64,7 +62,6 @@ public class InvalidateTokenResponseTests extends ESTestCase {
final int invalidatedTokens = randomInt(32);
final int previouslyInvalidatedTokens = randomInt(32);
builder.startObject()
.field("created", false)
.field("invalidated_tokens", invalidatedTokens)
.field("previously_invalidated_tokens", previouslyInvalidatedTokens)
.field("error_count", 0)
@ -82,7 +79,6 @@ public class InvalidateTokenResponseTests extends ESTestCase {
try (XContentParser parser = createParser(xContentType.xContent(), xContent)) {
final InvalidateTokenResponse response = InvalidateTokenResponse.fromXContent(parser);
assertThat(response.isCreated(), Matchers.equalTo(false));
assertThat(response.getInvalidatedTokens(), Matchers.equalTo(invalidatedTokens));
assertThat(response.getPreviouslyInvalidatedTokens(), Matchers.equalTo(previouslyInvalidatedTokens));
assertThat(response.getErrorsCount(), Matchers.equalTo(2));

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Nullable;
@ -137,57 +136,19 @@ public final class InvalidateTokenRequest extends ActionRequest {
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getVersion().before(Version.V_7_0_0)) {
if (Strings.isNullOrEmpty(tokenString)) {
throw new IllegalArgumentException("token is required for versions < v6.6.0");
}
out.writeString(tokenString);
} else {
out.writeOptionalString(tokenString);
}
if (out.getVersion().onOrAfter(Version.V_6_2_0)) {
if (out.getVersion().before(Version.V_7_0_0)) {
if (tokenType == null) {
throw new IllegalArgumentException("token type is not optional for versions > v6.2.0 and < v6.6.0");
}
out.writeVInt(tokenType.ordinal());
} else {
out.writeOptionalVInt(tokenType == null ? null : tokenType.ordinal());
}
} else if (tokenType == Type.REFRESH_TOKEN) {
throw new IllegalArgumentException("refresh token invalidation cannot be serialized with version [" + out.getVersion() + "]");
}
if (out.getVersion().onOrAfter(Version.V_7_0_0)) {
out.writeOptionalString(realmName);
out.writeOptionalString(userName);
} else if (realmName != null || userName != null) {
throw new IllegalArgumentException(
"realm or user token invalidation cannot be serialized with version [" + out.getVersion() + "]");
}
out.writeOptionalString(tokenString);
out.writeOptionalVInt(tokenType == null ? null : tokenType.ordinal());
out.writeOptionalString(realmName);
out.writeOptionalString(userName);
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
if (in.getVersion().before(Version.V_7_0_0)) {
tokenString = in.readString();
} else {
tokenString = in.readOptionalString();
}
if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
if (in.getVersion().before(Version.V_7_0_0)) {
int type = in.readVInt();
tokenType = Type.values()[type];
} else {
Integer type = in.readOptionalVInt();
tokenType = type == null ? null : Type.values()[type];
}
} else {
tokenType = Type.ACCESS_TOKEN;
}
if (in.getVersion().onOrAfter(Version.V_7_0_0)) {
realmName = in.readOptionalString();
userName = in.readOptionalString();
}
tokenString = in.readOptionalString();
Integer type = in.readOptionalVInt();
tokenType = type == null ? null : Type.values()[type];
realmName = in.readOptionalString();
userName = in.readOptionalString();
}
}

View File

@ -5,7 +5,6 @@
*/
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -14,8 +13,6 @@ import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Objects;
/**
@ -35,35 +32,16 @@ public final class InvalidateTokenResponse extends ActionResponse implements ToX
return result;
}
private boolean isCreated() {
return result.getInvalidatedTokens().size() > 0
&& result.getPreviouslyInvalidatedTokens().isEmpty()
&& result.getErrors().isEmpty();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
if (out.getVersion().before(Version.V_7_0_0)) {
out.writeBoolean(isCreated());
} else {
result.writeTo(out);
}
result.writeTo(out);
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
if (in.getVersion().before(Version.V_7_0_0)) {
final boolean created = in.readBoolean();
if (created) {
result = new TokensInvalidationResult(Arrays.asList(""), Collections.emptyList(), Collections.emptyList(), 0);
} else {
result = new TokensInvalidationResult(Collections.emptyList(), Arrays.asList(""), Collections.emptyList(), 0);
}
} else {
result = new TokensInvalidationResult(in);
}
result = new TokensInvalidationResult(in);
}
@Override

View File

@ -79,8 +79,6 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject()
//Remove created after PR is backported to 6.x
.field("created", isCreated())
.field("invalidated_tokens", invalidatedTokens.size())
.field("previously_invalidated_tokens", previouslyInvalidatedTokens.size())
.field("error_count", errors.size());
@ -104,10 +102,4 @@ public class TokensInvalidationResult implements ToXContentObject, Writeable {
out.writeCollection(errors, StreamOutput::writeException);
out.writeVInt(attemptCount);
}
private boolean isCreated() {
return this.getInvalidatedTokens().size() > 0
&& this.getPreviouslyInvalidatedTokens().isEmpty()
&& this.getErrors().isEmpty();
}
}

View File

@ -6,7 +6,6 @@
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
@ -14,7 +13,6 @@ import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import java.io.IOException;
@ -65,48 +63,6 @@ public class InvalidateTokenResponseTests extends ESTestCase {
}
}
public void testSerializationToPre66Version() throws IOException{
final Version version = VersionUtils.randomVersionBetween(random(), Version.V_6_2_0, Version.V_6_5_1);
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
// False as we have errors and previously invalidated tokens
assertThat(input.readBoolean(), equalTo(false));
}
}
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
// False as we have previously invalidated tokens
assertThat(input.readBoolean(), equalTo(false));
}
}
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Collections.emptyList(), Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
assertThat(input.readBoolean(), equalTo(true));
}
}
}
public void testToXContent() throws IOException {
List invalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
@ -118,7 +74,7 @@ public class InvalidateTokenResponseTests extends ESTestCase {
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo("{\"created\":false," +
equalTo("{" +
"\"invalidated_tokens\":" + invalidatedTokens.size() + "," +
"\"previously_invalidated_tokens\":" + previouslyInvalidatedTokens.size() + "," +
"\"error_count\":2," +

View File

@ -33,7 +33,9 @@ import static org.elasticsearch.xpack.core.ClientHelper.SECURITY_ORIGIN;
import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin;
/**
* Responsible for cleaning the invalidated tokens from the invalidated tokens index.
* Responsible for cleaning the invalidated and expired tokens from the security index.
* The document gets deleted if it was created more than 24 hours which is the maximum
* lifetime of a refresh token
*/
final class ExpiredTokenRemover extends AbstractRunnable {
private static final Logger logger = LogManager.getLogger(ExpiredTokenRemover.class);
@ -57,10 +59,8 @@ final class ExpiredTokenRemover extends AbstractRunnable {
final Instant now = Instant.now();
expiredDbq
.setQuery(QueryBuilders.boolQuery()
.filter(QueryBuilders.termsQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE, "token"))
.filter(QueryBuilders.boolQuery()
.should(QueryBuilders.rangeQuery("expiration_time").lte(now.toEpochMilli()))
.should(QueryBuilders.rangeQuery("creation_time").lte(now.minus(24L, ChronoUnit.HOURS).toEpochMilli()))));
.filter(QueryBuilders.termsQuery("doc_type", "token"))
.filter(QueryBuilders.rangeQuery("creation_time").lte(now.minus(24L, ChronoUnit.HOURS).toEpochMilli())));
logger.trace(() -> new ParameterizedMessage("Removing old tokens: [{}]", Strings.toString(expiredDbq)));
executeAsyncWithOrigin(client, SECURITY_ORIGIN, DeleteByQueryAction.INSTANCE, expiredDbq,
ActionListener.wrap(r -> {

View File

@ -19,14 +19,10 @@ import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest.OpType;
import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.get.MultiGetRequest;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.search.SearchRequest;
@ -165,7 +161,8 @@ public final class TokenService {
public static final Setting<TimeValue> DELETE_TIMEOUT = Setting.timeSetting("xpack.security.authc.token.delete.timeout",
TimeValue.MINUS_ONE, Property.NodeScope);
static final String INVALIDATED_TOKEN_DOC_TYPE = "invalidated-token";
private static final String TOKEN_DOC_TYPE = "token";
private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_";
static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1;
private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue();
private static final int MAX_RETRY_ATTEMPTS = 5;
@ -245,7 +242,7 @@ public final class TokenService {
try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
builder.startObject();
builder.field("doc_type", "token");
builder.field("doc_type", TOKEN_DOC_TYPE);
builder.field("creation_time", created.toEpochMilli());
if (includeRefreshToken) {
builder.startObject("refresh_token")
@ -293,15 +290,13 @@ public final class TokenService {
listener.onResponse(null);
} else {
try {
decodeAndValidateToken(token, ActionListener.wrap(listener::onResponse, e -> {
if (e instanceof IOException) {
// could happen with a token that is not ours
logger.debug("invalid token", e);
listener.onResponse(null);
decodeToken(token, ActionListener.wrap(userToken -> {
if (userToken != null) {
checkIfTokenIsValid(userToken, listener);
} else {
listener.onFailure(e);
listener.onResponse(null);
}
}));
}, listener::onFailure));
} catch (IOException e) {
// could happen with a token that is not ours
logger.debug("invalid token", e);
@ -331,22 +326,6 @@ public final class TokenService {
));
}
private void decodeAndValidateToken(String token, ActionListener<UserToken> listener) throws IOException {
decodeToken(token, ActionListener.wrap(userToken -> {
if (userToken != null) {
Instant currentTime = clock.instant();
if (currentTime.isAfter(userToken.getExpirationTime())) {
// token expired
listener.onFailure(traceLog("decode token", token, expiredTokenException()));
} else {
checkIfTokenIsRevoked(userToken, listener);
}
} else {
listener.onResponse(null);
}
}, listener::onFailure));
}
/*
* Asynchronously decodes the string representation of a {@link UserToken}. The process for
* this is asynchronous as we may need to compute a key, which can be computationally expensive
@ -373,55 +352,51 @@ public final class TokenService {
try {
final byte[] iv = in.readByteArray();
final Cipher cipher = getDecryptionCipher(iv, decodeKey, version, decodedSalt);
if (version.onOrAfter(Version.V_6_2_0)) {
// we only have the id and need to get the token from the doc!
decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> {
if (securityIndex.isAvailable() == false) {
logger.warn("failed to get token [{}] since index is not available", tokenId);
listener.onResponse(null);
} else {
securityIndex.checkIndexVersionThenExecute(
ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)),
() -> {
final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE,
getTokenDocumentId(tokenId)).request();
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest,
ActionListener.<GetResponse>wrap(response -> {
if (response.isExists()) {
Map<String, Object> accessTokenSource =
(Map<String, Object>) response.getSource().get("access_token");
if (accessTokenSource == null) {
onFailure.accept(new IllegalStateException(
"token document is missing the access_token field"));
} else if (accessTokenSource.containsKey("user_token") == false) {
onFailure.accept(new IllegalStateException(
"token document is missing the user_token field"));
} else {
Map<String, Object> userTokenSource =
(Map<String, Object>) accessTokenSource.get("user_token");
listener.onResponse(UserToken.fromSourceMap(userTokenSource));
}
decryptTokenId(in, cipher, version, ActionListener.wrap(tokenId -> {
if (securityIndex.isAvailable() == false) {
logger.warn("failed to get token [{}] since index is not available", tokenId);
listener.onResponse(null);
} else {
securityIndex.checkIndexVersionThenExecute(
ex -> listener.onFailure(traceLog("prepare security index", tokenId, ex)),
() -> {
final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE,
getTokenDocumentId(tokenId)).request();
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("decode token", tokenId, ex));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest,
ActionListener.<GetResponse>wrap(response -> {
if (response.isExists()) {
Map<String, Object> accessTokenSource =
(Map<String, Object>) response.getSource().get("access_token");
if (accessTokenSource == null) {
onFailure.accept(new IllegalStateException(
"token document is missing the access_token field"));
} else if (accessTokenSource.containsKey("user_token") == false) {
onFailure.accept(new IllegalStateException(
"token document is missing the user_token field"));
} else {
onFailure.accept(
new IllegalStateException("token document is missing and must be present"));
Map<String, Object> userTokenSource =
(Map<String, Object>) accessTokenSource.get("user_token");
listener.onResponse(UserToken.fromSourceMap(userTokenSource));
}
}, e -> {
// if the index or the shard is not there / available we assume that
// the token is not valid
if (isShardNotAvailableException(e)) {
logger.warn("failed to get token [{}] since index is not available", tokenId);
listener.onResponse(null);
} else {
logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e);
listener.onFailure(e);
}
}), client::get);
});
}}, listener::onFailure));
} else {
decryptToken(in, cipher, version, listener);
}
} else {
onFailure.accept(
new IllegalStateException("token document is missing and must be present"));
}
}, e -> {
// if the index or the shard is not there / available we assume that
// the token is not valid
if (isShardNotAvailableException(e)) {
logger.warn("failed to get token [{}] since index is not available", tokenId);
listener.onResponse(null);
} else {
logger.error(new ParameterizedMessage("failed to get token [{}]", tokenId), e);
listener.onFailure(e);
}
}), client::get);
});
}
}, listener::onFailure));
} catch (GeneralSecurityException e) {
// could happen with a token that is not ours
logger.warn("invalid token", e);
@ -456,14 +431,6 @@ public final class TokenService {
}
}
private static void decryptToken(StreamInput in, Cipher cipher, Version version, ActionListener<UserToken> listener) throws
IOException {
try (CipherInputStream cis = new CipherInputStream(in, cipher); StreamInput decryptedInput = new InputStreamStreamInput(cis)) {
decryptedInput.setVersion(version);
listener.onResponse(new UserToken(decryptedInput));
}
}
private static void decryptTokenId(StreamInput in, Cipher cipher, Version version, ActionListener<String> listener) throws IOException {
try (CipherInputStream cis = new CipherInputStream(in, cipher); StreamInput decryptedInput = new InputStreamStreamInput(cis)) {
decryptedInput.setVersion(version);
@ -473,10 +440,7 @@ public final class TokenService {
/**
* This method performs the steps necessary to invalidate a token so that it may no longer be
* used. The process of invalidation involves a step that is needed for backwards compatibility
* with versions prior to 6.2.0; this step records an entry to indicate that a token with a
* given id has been expired. The second step is to record the invalidation for tokens that
* have been created on versions on or after 6.2; this step involves performing an update to
* used. The process of invalidation involves performing an update to
* the token document and setting the <code>invalidated</code> field to <code>true</code>
*/
public void invalidateAccessToken(String tokenString, ActionListener<TokensInvalidationResult> listener) {
@ -491,9 +455,8 @@ public final class TokenService {
if (userToken == null) {
listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException()));
} else {
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0),
expirationEpochMilli, null);
indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0),
"access_token", null);
}
}, listener::onFailure));
} catch (IOException e) {
@ -515,8 +478,7 @@ public final class TokenService {
listener.onFailure(new IllegalArgumentException("token must be provided"));
} else {
maybeStartTokenRemover();
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), expirationEpochMilli, null);
indexInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), "access_token", null);
}
}
@ -591,84 +553,14 @@ public final class TokenService {
*/
private void invalidateAllTokens(Collection<String> accessTokenIds, ActionListener<TokensInvalidationResult> listener) {
maybeStartTokenRemover();
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
// Invalidate the refresh tokens first so that they cannot be used to get new
// access tokens while we invalidate the access tokens we currently know about
indexInvalidation(accessTokenIds, ActionListener.wrap(result ->
indexBwcInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()),
expirationEpochMilli, result),
indexInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()),
"access_token", result),
listener::onFailure), new AtomicInteger(0), "refresh_token", null);
}
/**
* Performs the actual bwc invalidation of a collection of tokens and then kicks off the new invalidation method.
*
* @param tokenIds the collection of token ids or token document ids that should be invalidated
* @param listener the listener to notify upon completion
* @param attemptCount the number of attempts to invalidate that have already been tried
* @param expirationEpochMilli the expiration time as milliseconds since the epoch
* @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating
* tokens up to the point of the retry. This result is added to the result of the current attempt
*/
private void indexBwcInvalidation(Collection<String> tokenIds, ActionListener<TokensInvalidationResult> listener,
AtomicInteger attemptCount, long expirationEpochMilli,
@Nullable TokensInvalidationResult previousResult) {
if (tokenIds.isEmpty()) {
logger.warn("No tokens provided for invalidation");
listener.onFailure(invalidGrantException("No tokens provided for invalidation"));
} else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(),
attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate tokens"));
} else {
BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
for (String tokenId : tokenIds) {
final String invalidatedTokenId = getInvalidatedTokenDocumentId(tokenId);
IndexRequest indexRequest = client.prepareIndex(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId)
.setOpType(OpType.CREATE)
.setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli)
.request();
bulkRequestBuilder.add(indexRequest);
}
bulkRequestBuilder.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL);
final BulkRequest bulkRequest = bulkRequestBuilder.request();
securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequest,
ActionListener.<BulkResponse>wrap(bulkResponse -> {
List<String> retryTokenIds = new ArrayList<>();
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
if (bulkItemResponse.isFailed()) {
Throwable cause = bulkItemResponse.getFailure().getCause();
logger.error(cause.getMessage());
traceLog("(bwc) invalidate tokens", cause);
if (isShardNotAvailableException(cause)) {
retryTokenIds.add(getTokenIdFromInvalidatedTokenDocumentId(bulkItemResponse.getFailure().getId()));
} else if ((cause instanceof VersionConflictEngineException) == false){
// We don't handle VersionConflictEngineException, the ticket has been invalidated
listener.onFailure(bulkItemResponse.getFailure().getCause());
}
}
}
if (retryTokenIds.isEmpty() == false) {
attemptCount.incrementAndGet();
indexBwcInvalidation(retryTokenIds, listener, attemptCount, expirationEpochMilli, previousResult);
}
indexInvalidation(tokenIds, listener, attemptCount, "access_token", previousResult);
}, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e);
traceLog("(bwc) invalidate tokens", cause);
if (isShardNotAvailableException(cause)) {
attemptCount.incrementAndGet();
indexBwcInvalidation(tokenIds, listener, attemptCount, expirationEpochMilli, previousResult);
} else {
listener.onFailure(e);
}
}),
client::bulk));
}
}
/**
* Performs the actual invalidation of a collection of tokens
*
@ -777,7 +669,7 @@ public final class TokenService {
} else {
SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setQuery(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
.filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
.setVersion(true)
.request();
@ -965,7 +857,7 @@ public final class TokenService {
} else {
final Instant now = clock.instant();
final BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
.filter(QueryBuilders.termQuery("access_token.realm", realmName))
.filter(QueryBuilders.boolQuery()
.should(QueryBuilders.boolQuery()
@ -1010,7 +902,7 @@ public final class TokenService {
} else {
final Instant now = clock.instant();
final BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.termQuery("doc_type", TOKEN_DOC_TYPE))
.filter(QueryBuilders.boolQuery()
.should(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("access_token.invalidated", false))
@ -1096,36 +988,19 @@ public final class TokenService {
}
}
private static String getInvalidatedTokenDocumentId(UserToken userToken) {
return getInvalidatedTokenDocumentId(userToken.getId());
}
private static String getInvalidatedTokenDocumentId(String id) {
return INVALIDATED_TOKEN_DOC_TYPE + "_" + id;
}
private static String getTokenDocumentId(UserToken userToken) {
return getTokenDocumentId(userToken.getId());
}
private static String getTokenDocumentId(String id) {
return "token_" + id;
return TOKEN_DOC_ID_PREFIX + id;
}
private static String getTokenIdFromDocumentId(String docId) {
if (docId.startsWith("token_") == false) {
if (docId.startsWith(TOKEN_DOC_ID_PREFIX) == false) {
throw new IllegalStateException("TokenDocument ID [" + docId + "] has unexpected value");
} else {
return docId.substring("token_".length());
}
}
private static String getTokenIdFromInvalidatedTokenDocumentId(String docId) {
final String invalidatedTokenDocPrefix = INVALIDATED_TOKEN_DOC_TYPE + "_";
if (docId.startsWith(invalidatedTokenDocPrefix) == false) {
throw new IllegalStateException("InvalidatedTokenDocument ID [" + docId + "] has unexpected value");
} else {
return docId.substring(invalidatedTokenDocPrefix.length());
return docId.substring(TOKEN_DOC_ID_PREFIX.length());
}
}
@ -1136,70 +1011,53 @@ public final class TokenService {
}
/**
* Checks if the token has been stored as a revoked token to ensure we do not allow tokens that
* have been explicitly cleared.
* Checks if the access token has been explicitly invalidated
*/
private void checkIfTokenIsRevoked(UserToken userToken, ActionListener<UserToken> listener) {
private void checkIfTokenIsValid(UserToken userToken, ActionListener<UserToken> listener) {
Instant currentTime = clock.instant();
if (currentTime.isAfter(userToken.getExpirationTime())) {
listener.onFailure(traceLog("validate token", userToken.getId(), expiredTokenException()));
}
if (securityIndex.indexExists() == false) {
// index doesn't exist so the token is considered valid. it is important to note that
// we do not use isAvailable as the lack of a shard being available is not equivalent
// to the index not existing in the case of revocation checking.
listener.onResponse(userToken);
// index doesn't exist so the token is considered invalid as we cannot verify its validity
logger.warn("failed to validate token [{}] since the security index doesn't exist", userToken.getId());
listener.onResponse(null);
} else {
securityIndex.checkIndexVersionThenExecute(listener::onFailure, () -> {
MultiGetRequest mGetRequest = client.prepareMultiGet()
.add(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, getInvalidatedTokenDocumentId(userToken))
.add(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, getTokenDocumentId(userToken))
.request();
final GetRequest getRequest = client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE,
getTokenDocumentId(userToken)).request();
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("check token state", userToken.getId(), ex));
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN,
mGetRequest,
new ActionListener<MultiGetResponse>() {
@Override
public void onResponse(MultiGetResponse response) {
MultiGetItemResponse[] itemResponse = response.getResponses();
if (itemResponse[0].isFailed()) {
onFailure(itemResponse[0].getFailure().getFailure());
} else if (itemResponse[0].getResponse().isExists()) {
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest,
ActionListener.<GetResponse>wrap(response -> {
if (response.isExists()) {
Map<String, Object> source = response.getSource();
Map<String, Object> accessTokenSource = (Map<String, Object>) source.get("access_token");
if (accessTokenSource == null) {
onFailure.accept(new IllegalStateException("token document is missing access_token field"));
} else {
Boolean invalidated = (Boolean) accessTokenSource.get("invalidated");
if (invalidated == null) {
onFailure.accept(new IllegalStateException("token document is missing invalidated field"));
} else if (invalidated) {
onFailure.accept(expiredTokenException());
} else if (itemResponse[1].isFailed()) {
onFailure(itemResponse[1].getFailure().getFailure());
} else if (itemResponse[1].getResponse().isExists()) {
Map<String, Object> source = itemResponse[1].getResponse().getSource();
Map<String, Object> accessTokenSource = (Map<String, Object>) source.get("access_token");
if (accessTokenSource == null) {
onFailure.accept(new IllegalStateException("token document is missing access_token field"));
} else {
Boolean invalidated = (Boolean) accessTokenSource.get("invalidated");
if (invalidated == null) {
onFailure.accept(new IllegalStateException("token document is missing invalidated field"));
} else if (invalidated) {
onFailure.accept(expiredTokenException());
} else {
listener.onResponse(userToken);
}
}
} else if (userToken.getVersion().onOrAfter(Version.V_6_2_0)) {
onFailure.accept(new IllegalStateException("token document is missing and must be present"));
} else {
listener.onResponse(userToken);
}
}
@Override
public void onFailure(Exception e) {
// if the index or the shard is not there / available we assume that
// the token is not valid
if (isShardNotAvailableException(e)) {
logger.warn("failed to get token [{}] since index is not available", userToken.getId());
listener.onResponse(null);
} else {
logger.error(new ParameterizedMessage("failed to get token [{}]", userToken.getId()), e);
listener.onFailure(e);
}
}
}, client::multiGet);
} else {
onFailure.accept(new IllegalStateException("token document is missing and must be present"));
}
}, e -> {
// if the index or the shard is not there / available we assume that
// the token is not valid
if (isShardNotAvailableException(e)) {
logger.warn("failed to get token [{}] since index is not available", userToken.getId());
listener.onResponse(null);
} else {
logger.error(new ParameterizedMessage("failed to get token [{}]", userToken.getId()), e);
listener.onFailure(e);
}
}), client::get);
});
}
}
@ -1209,10 +1067,6 @@ public final class TokenService {
return expirationDelay;
}
private Instant getExpirationTime() {
return getExpirationTime(clock.instant());
}
private Instant getExpirationTime(Instant now) {
return now.plusSeconds(expirationDelay.getSeconds());
}
@ -1258,11 +1112,7 @@ public final class TokenService {
new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion()));
StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) {
encryptedStreamOutput.setVersion(userToken.getVersion());
if (userToken.getVersion().onOrAfter(Version.V_6_2_0)) {
encryptedStreamOutput.writeString(userToken.getId());
} else {
userToken.writeTo(encryptedStreamOutput);
}
encryptedStreamOutput.writeString(userToken.getId());
encryptedStreamOutput.close();
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}

View File

@ -310,32 +310,26 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
assertThat(((TermQueryBuilder) filter1.get(1)).fieldName(), equalTo("refresh_token.token"));
assertThat(((TermQueryBuilder) filter1.get(1)).value(), equalTo(tokenToInvalidate1.v2()));
assertThat(bulkRequests.size(), equalTo(6)); // 4 updates (refresh-token + access-token) plus 2 indexes (bwc-invalidate * 2)
assertThat(bulkRequests.size(), equalTo(4)); // 4 updates (refresh-token + access-token)
// Invalidate refresh token 1
assertThat(bulkRequests.get(0).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(0).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
UpdateRequest updateRequest1 = (UpdateRequest) bulkRequests.get(0).requests().get(0);
assertThat(updateRequest1.toString().contains("refresh_token"), equalTo(true));
// BWC incalidate access token 1
assertThat(bulkRequests.get(1).requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate1.v1().getId()));
// Invalidate access token 1
assertThat(bulkRequests.get(2).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
UpdateRequest updateRequest2 = (UpdateRequest) bulkRequests.get(2).requests().get(0);
assertThat(bulkRequests.get(1).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
UpdateRequest updateRequest2 = (UpdateRequest) bulkRequests.get(1).requests().get(0);
assertThat(updateRequest2.toString().contains("access_token"), equalTo(true));
// Invalidate refresh token 2
assertThat(bulkRequests.get(2).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
UpdateRequest updateRequest3 = (UpdateRequest) bulkRequests.get(2).requests().get(0);
assertThat(updateRequest3.toString().contains("refresh_token"), equalTo(true));
// Invalidate access token 2
assertThat(bulkRequests.get(3).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(3).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
UpdateRequest updateRequest3 = (UpdateRequest) bulkRequests.get(3).requests().get(0);
assertThat(updateRequest3.toString().contains("refresh_token"), equalTo(true));
// BWC incalidate access token 2
assertThat(bulkRequests.get(4).requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequests.get(4).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate2.v1().getId()));
// Invalidate access token 2
assertThat(bulkRequests.get(5).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(5).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
UpdateRequest updateRequest4 = (UpdateRequest) bulkRequests.get(5).requests().get(0);
UpdateRequest updateRequest4 = (UpdateRequest) bulkRequests.get(3).requests().get(0);
assertThat(updateRequest4.toString().contains("access_token"), equalTo(true));
}

View File

@ -241,7 +241,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
final PlainActionFuture<Tuple<UserToken, String>> future = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, future, tokenMetaData, true);
final UserToken userToken = future.actionGet().v1();
mockGetTokenFromId(userToken, client);
mockGetTokenFromId(userToken, false, client);
final String tokenString = tokenService.getUserTokenString(userToken);
final SamlLogoutRequest request = new SamlLogoutRequest();
@ -256,17 +256,13 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
assertThat(indexRequest1, notNullValue());
assertThat(indexRequest1.id(), startsWith("token"));
assertThat(bulkRequests.size(), equalTo(2));
final BulkRequest bulkRequest1 = bulkRequests.get(0);
assertThat(bulkRequest1.requests().size(), equalTo(1));
assertThat(bulkRequest1.requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequest1.requests().get(0).id(), startsWith("invalidated-token_"));
assertThat(bulkRequests.size(), equalTo(1));
final BulkRequest bulkRequest2 = bulkRequests.get(1);
assertThat(bulkRequest2.requests().size(), equalTo(1));
assertThat(bulkRequest2.requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequest2.requests().get(0).id(), startsWith("token_"));
assertThat(bulkRequest2.requests().get(0).toString(), containsString("\"access_token\":{\"invalidated\":true"));
final BulkRequest bulkRequest = bulkRequests.get(0);
assertThat(bulkRequest.requests().size(), equalTo(1));
assertThat(bulkRequest.requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequest.requests().get(0).id(), startsWith("token_"));
assertThat(bulkRequest.requests().get(0).toString(), containsString("\"access_token\":{\"invalidated\":true"));
}
}

View File

@ -12,12 +12,8 @@ import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.get.GetAction;
import org.elasticsearch.action.get.GetRequestBuilder;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetAction;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.get.MultiGetRequest;
import org.elasticsearch.action.get.MultiGetRequestBuilder;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexRequestBuilder;
@ -28,7 +24,6 @@ import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
@ -88,7 +83,6 @@ import java.util.function.Consumer;
import static org.elasticsearch.test.SecurityTestsUtils.assertAuthenticationException;
import static org.elasticsearch.xpack.core.security.support.Exceptions.authenticationError;
import static org.elasticsearch.xpack.security.authc.TokenServiceTests.mockCheckTokenInvalidationFromId;
import static org.elasticsearch.xpack.security.authc.TokenServiceTests.mockGetTokenFromId;
import static org.hamcrest.Matchers.arrayContaining;
import static org.hamcrest.Matchers.contains;
@ -934,8 +928,7 @@ public class AuthenticationServiceTests extends ESTestCase {
}
String token = tokenService.getUserTokenString(tokenFuture.get().v1());
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
mockGetTokenFromId(tokenFuture.get().v1(), client);
mockCheckTokenInvalidationFromId(tokenFuture.get().v1(), client);
mockGetTokenFromId(tokenFuture.get().v1(), false, client);
when(securityIndex.isAvailable()).thenReturn(true);
when(securityIndex.indexExists()).thenReturn(true);
try (ThreadContext.StoredContext ignore = threadContext.stashContext()) {
@ -1017,32 +1010,7 @@ public class AuthenticationServiceTests extends ESTestCase {
tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap(), true);
}
String token = tokenService.getUserTokenString(tokenFuture.get().v1());
mockGetTokenFromId(tokenFuture.get().v1(), client);
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
doAnswer(invocationOnMock -> {
ActionListener<MultiGetResponse> listener = (ActionListener<MultiGetResponse>) invocationOnMock.getArguments()[1];
MultiGetResponse response = mock(MultiGetResponse.class);
MultiGetItemResponse[] responses = new MultiGetItemResponse[2];
when(response.getResponses()).thenReturn(responses);
final boolean newExpired = randomBoolean();
GetResponse oldGetResponse = mock(GetResponse.class);
when(oldGetResponse.isExists()).thenReturn(newExpired == false);
responses[0] = new MultiGetItemResponse(oldGetResponse, null);
GetResponse getResponse = mock(GetResponse.class);
responses[1] = new MultiGetItemResponse(getResponse, null);
when(getResponse.isExists()).thenReturn(newExpired);
if (newExpired) {
Map<String, Object> source = MapBuilder.<String, Object>newMapBuilder()
.put("access_token", Collections.singletonMap("invalidated", true))
.immutableMap();
when(getResponse.getSource()).thenReturn(source);
}
listener.onResponse(response);
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
mockGetTokenFromId(tokenFuture.get().v1(), true, client);
doAnswer(invocationOnMock -> {
((Runnable) invocationOnMock.getArguments()[1]).run();
return null;

View File

@ -151,7 +151,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
assertBusy(() -> {
SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setSource(SearchSourceBuilder.searchSource()
.query(QueryBuilders.termQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE)))
.query(QueryBuilders.termQuery("doc_type", "token")))
.setSize(1)
.setTerminateAfter(1)
.get();
@ -159,11 +159,11 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
docId.set(searchResponse.getHits().getAt(0).getId());
});
// hack doc to modify the time to the day before
// hack doc to modify the creation time to the day before
Instant dayBefore = created.minus(1L, ChronoUnit.DAYS);
assertTrue(Instant.now().isAfter(dayBefore));
client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, "doc", docId.get())
.setDoc("expiration_time", dayBefore.toEpochMilli())
.setDoc("creation_time", dayBefore.toEpochMilli())
.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE)
.get();
@ -183,8 +183,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
client.admin().indices().prepareRefresh(SecurityIndexManager.SECURITY_INDEX_NAME).get();
SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setSource(SearchSourceBuilder.searchSource()
.query(QueryBuilders.termQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE)))
.setSize(0)
.query(QueryBuilders.termQuery("doc_type", "token")))
.setTerminateAfter(1)
.get();
assertThat(searchResponse.getHits().getTotalHits().value, equalTo(0L));

View File

@ -12,11 +12,6 @@ import org.elasticsearch.action.get.GetAction;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetRequestBuilder;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetAction;
import org.elasticsearch.action.get.MultiGetItemResponse;
import org.elasticsearch.action.get.MultiGetRequest;
import org.elasticsearch.action.get.MultiGetRequestBuilder;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexRequestBuilder;
@ -27,7 +22,6 @@ import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.MapBuilder;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
@ -104,23 +98,6 @@ public class TokenServiceTests extends ESTestCase {
.setId((String) invocationOnMock.getArguments()[2]);
return builder;
}).when(client).prepareGet(anyString(), anyString(), anyString());
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
doAnswer(invocationOnMock -> {
ActionListener<MultiGetResponse> listener = (ActionListener<MultiGetResponse>) invocationOnMock.getArguments()[1];
MultiGetResponse response = mock(MultiGetResponse.class);
MultiGetItemResponse[] responses = new MultiGetItemResponse[2];
when(response.getResponses()).thenReturn(responses);
GetResponse oldGetResponse = mock(GetResponse.class);
when(oldGetResponse.isExists()).thenReturn(false);
responses[0] = new MultiGetItemResponse(oldGetResponse, null);
GetResponse getResponse = mock(GetResponse.class);
responses[1] = new MultiGetItemResponse(getResponse, null);
when(getResponse.isExists()).thenReturn(false);
listener.onResponse(response);
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
when(client.prepareIndex(any(String.class), any(String.class), any(String.class)))
.thenReturn(new IndexRequestBuilder(client, IndexAction.INSTANCE));
when(client.prepareUpdate(any(String.class), any(String.class), any(String.class)))
@ -168,8 +145,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", randomFrom("Bearer ", "BEARER ", "bearer ") + tokenService.getUserTokenString(token));
@ -215,8 +191,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -244,7 +219,7 @@ public class TokenServiceTests extends ESTestCase {
requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken));
mockGetTokenFromId(newToken);
mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -275,8 +250,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -306,8 +280,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -351,7 +324,7 @@ public class TokenServiceTests extends ESTestCase {
requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(newToken));
mockGetTokenFromId(newToken);
mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
tokenService.getAndValidateToken(requestContext, future);
@ -368,8 +341,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -413,33 +385,10 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
doAnswer(invocationOnMock -> {
ActionListener<MultiGetResponse> listener = (ActionListener<MultiGetResponse>) invocationOnMock.getArguments()[1];
MultiGetResponse response = mock(MultiGetResponse.class);
MultiGetItemResponse[] responses = new MultiGetItemResponse[2];
when(response.getResponses()).thenReturn(responses);
final boolean newExpired = randomBoolean();
GetResponse oldGetResponse = mock(GetResponse.class);
when(oldGetResponse.isExists()).thenReturn(newExpired == false);
responses[0] = new MultiGetItemResponse(oldGetResponse, null);
GetResponse getResponse = mock(GetResponse.class);
responses[1] = new MultiGetItemResponse(getResponse, null);
when(getResponse.isExists()).thenReturn(newExpired);
if (newExpired) {
Map<String, Object> source = MapBuilder.<String, Object>newMapBuilder()
.put("access_token", Collections.singletonMap("invalidated", true))
.immutableMap();
when(getResponse.getSource()).thenReturn(source);
}
listener.onResponse(response);
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
mockGetTokenFromId(token, true);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
mockGetTokenFromId(token);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -466,8 +415,7 @@ public class TokenServiceTests extends ESTestCase {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
mockGetTokenFromId(token);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -577,7 +525,7 @@ public class TokenServiceTests extends ESTestCase {
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
final UserToken token = tokenFuture.get().v1();
assertNotNull(token);
mockGetTokenFromId(token);
//mockGetTokenFromId(token, false);
ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token));
@ -586,7 +534,7 @@ public class TokenServiceTests extends ESTestCase {
ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
listener.onFailure(new NoShardAvailableActionException(new ShardId(new Index("foo", "uuid"), 0), "shard oh shard"));
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
}).when(client).get(any(GetRequest.class), any(ActionListener.class));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
@ -606,19 +554,19 @@ public class TokenServiceTests extends ESTestCase {
when(securityIndex.isAvailable()).thenReturn(true);
when(securityIndex.indexExists()).thenReturn(true);
mockCheckTokenInvalidationFromId(token);
mockGetTokenFromId(token, false);
future = new PlainActionFuture<>();
tokenService.getAndValidateToken(requestContext, future);
assertEquals(token.getAuthentication(), future.get().getAuthentication());
}
}
public void testGetAuthenticationWorksWithExpiredToken() throws Exception {
public void testGetAuthenticationWorksWithExpiredUserToken() throws Exception {
TokenService tokenService =
new TokenService(tokenServiceEnabledSettings, Clock.systemUTC(), client, securityIndex, clusterService);
Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null);
UserToken expired = new UserToken(authentication, Instant.now().minus(3L, ChronoUnit.DAYS));
mockGetTokenFromId(expired);
mockGetTokenFromId(expired, false);
String userTokenString = tokenService.getUserTokenString(expired);
PlainActionFuture<Tuple<Authentication, Map<String, Object>>> authFuture = new PlainActionFuture<>();
tokenService.getAuthenticationAndMetaData(userTokenString, authFuture);
@ -626,62 +574,30 @@ public class TokenServiceTests extends ESTestCase {
assertEquals(authentication, retrievedAuth);
}
private void mockGetTokenFromId(UserToken userToken) {
mockGetTokenFromId(userToken, client);
private void mockGetTokenFromId(UserToken userToken, boolean isExpired) {
mockGetTokenFromId(userToken, isExpired, client);
}
public static void mockGetTokenFromId(UserToken userToken, Client client) {
public static void mockGetTokenFromId(UserToken userToken, boolean isExpired, Client client) {
doAnswer(invocationOnMock -> {
GetRequest getRequest = (GetRequest) invocationOnMock.getArguments()[0];
ActionListener<GetResponse> getResponseListener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
GetResponse getResponse = mock(GetResponse.class);
if (userToken.getId().equals(getRequest.id().replace("token_", ""))) {
when(getResponse.isExists()).thenReturn(true);
Map<String, Object> sourceMap = new HashMap<>();
try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
userToken.toXContent(builder, ToXContent.EMPTY_PARAMS);
sourceMap.put("access_token",
Collections.singletonMap("user_token",
XContentHelper.convertToMap(XContentType.JSON.xContent(), Strings.toString(builder), false)));
}
when(getResponse.getSource()).thenReturn(sourceMap);
}
getResponseListener.onResponse(getResponse);
return Void.TYPE;
}).when(client).get(any(GetRequest.class), any(ActionListener.class));
}
private void mockCheckTokenInvalidationFromId(UserToken userToken) {
mockCheckTokenInvalidationFromId(userToken, client);
}
public static void mockCheckTokenInvalidationFromId(UserToken userToken, Client client) {
doAnswer(invocationOnMock -> {
MultiGetRequest request = (MultiGetRequest) invocationOnMock.getArguments()[0];
ActionListener<MultiGetResponse> listener = (ActionListener<MultiGetResponse>) invocationOnMock.getArguments()[1];
MultiGetResponse response = mock(MultiGetResponse.class);
MultiGetItemResponse[] responses = new MultiGetItemResponse[2];
when(response.getResponses()).thenReturn(responses);
GetResponse legacyResponse = mock(GetResponse.class);
responses[0] = new MultiGetItemResponse(legacyResponse, null);
when(legacyResponse.isExists()).thenReturn(false);
GetResponse tokenResponse = mock(GetResponse.class);
if (userToken.getId().equals(request.getItems().get(1).id().replace("token_", ""))) {
when(tokenResponse.isExists()).thenReturn(true);
GetRequest request = (GetRequest) invocationOnMock.getArguments()[0];
ActionListener<GetResponse> listener = (ActionListener<GetResponse>) invocationOnMock.getArguments()[1];
GetResponse response = mock(GetResponse.class);
if (userToken.getId().equals(request.id().replace("token_", ""))) {
when(response.isExists()).thenReturn(true);
Map<String, Object> sourceMap = new HashMap<>();
try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) {
userToken.toXContent(builder, ToXContent.EMPTY_PARAMS);
Map<String, Object> accessTokenMap = new HashMap<>();
accessTokenMap.put("user_token",
XContentHelper.convertToMap(XContentType.JSON.xContent(), Strings.toString(builder), false));
accessTokenMap.put("invalidated", false);
accessTokenMap.put("invalidated", isExpired);
sourceMap.put("access_token", accessTokenMap);
}
when(tokenResponse.getSource()).thenReturn(sourceMap);
when(response.getSource()).thenReturn(sourceMap);
}
responses[1] = new MultiGetItemResponse(tokenResponse, null);
listener.onResponse(response);
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
}).when(client).get(any(GetRequest.class), any(ActionListener.class));
}
}

View File

@ -32,8 +32,7 @@ public class TokensInvalidationResultTests extends ESTestCase {
result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo(
"{\"created\":false," +
"\"invalidated_tokens\":2," +
"{\"invalidated_tokens\":2," +
"\"previously_invalidated_tokens\":2," +
"\"error_count\":2," +
"\"error_details\":[" +
@ -64,8 +63,7 @@ public class TokensInvalidationResultTests extends ESTestCase {
result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo(
"{\"created\":true," +
"\"invalidated_tokens\":2," +
"{\"invalidated_tokens\":2," +
"\"previously_invalidated_tokens\":0," +
"\"error_count\":0" +
"}"));

View File

@ -79,7 +79,6 @@ teardown:
body:
token: $token
- match: { created: true}
- match: { invalidated_tokens: 1 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }
@ -120,7 +119,6 @@ teardown:
body:
username: "token_user"
- match: { created: true}
- match: { invalidated_tokens: 2 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }
@ -162,7 +160,6 @@ teardown:
body:
realm_name: "default_native"
- match: { created: true}
- match: { invalidated_tokens: 2 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }