From 7b9ca62174216097493e83d1f127c1098c4f6737 Mon Sep 17 00:00:00 2001 From: Ioannis Kakavas Date: Tue, 18 Dec 2018 10:05:50 +0200 Subject: [PATCH] Enhance Invalidate Token API (#35388) This change: - Adds functionality to invalidate all (refresh+access) tokens for all users of a realm - Adds functionality to invalidate all (refresh+access)tokens for a user in all realms - Adds functionality to invalidate all (refresh+access) tokens for a user in a specific realm - Changes the response format for the invalidate token API to contain information about the number of the invalidated tokens and possible errors that were encountered. - Updates the API Documentation After back-porting to 6.x, the `created` field will be removed from master as a field in the response Resolves: #35115 Relates: #34556 --- .../SecurityDocumentationIT.java | 1 + .../security/invalidate-token.asciidoc | 2 +- .../security/invalidate-tokens.asciidoc | 90 ++- .../action/token/InvalidateTokenAction.java | 2 +- .../action/token/InvalidateTokenRequest.java | 133 ++++- .../token/InvalidateTokenRequestBuilder.java | 16 + .../action/token/InvalidateTokenResponse.java | 66 ++- .../support/TokensInvalidationResult.java | 113 ++++ .../core/security/client/SecurityClient.java | 4 + .../action/token/CreateTokenRequestTests.java | 1 - .../token/InvalidateTokenRequestTests.java | 82 +++ .../token/InvalidateTokenResponseTests.java | 141 +++++ .../TransportSamlInvalidateSessionAction.java | 46 +- .../saml/TransportSamlLogoutAction.java | 3 +- .../token/TransportInvalidateTokenAction.java | 11 +- .../xpack/security/authc/TokenService.java | 516 ++++++++++++------ .../oauth2/RestInvalidateTokenAction.java | 66 +-- ...sportSamlInvalidateSessionActionTests.java | 72 ++- .../saml/TransportSamlLogoutActionTests.java | 49 +- .../security/authc/TokenAuthIntegTests.java | 96 +++- .../security/authc/TokenServiceTests.java | 3 +- .../TokensInvalidationResultTests.java | 74 +++ .../RestInvalidateTokenActionTests.java | 61 +++ .../rest-api-spec/test/token/10_basic.yml | 90 ++- 24 files changed, 1429 insertions(+), 309 deletions(-) create mode 100644 x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestTests.java create mode 100644 x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java create mode 100644 x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java create mode 100644 x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenActionTests.java diff --git a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/SecurityDocumentationIT.java b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/SecurityDocumentationIT.java index 8bd285cd31f..6cd56774086 100644 --- a/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/SecurityDocumentationIT.java +++ b/client/rest-high-level/src/test/java/org/elasticsearch/client/documentation/SecurityDocumentationIT.java @@ -1317,6 +1317,7 @@ public class SecurityDocumentationIT extends ESRestHighLevelClientTestCase { } } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/pull/36362") public void testInvalidateToken() throws Exception { RestHighLevelClient client = highLevelClient(); diff --git a/docs/java-rest/high-level/security/invalidate-token.asciidoc b/docs/java-rest/high-level/security/invalidate-token.asciidoc index ecb3fedb56f..65e0f15bd86 100644 --- a/docs/java-rest/high-level/security/invalidate-token.asciidoc +++ b/docs/java-rest/high-level/security/invalidate-token.asciidoc @@ -36,4 +36,4 @@ The returned +{response}+ contains a single property: ["source","java",subs="attributes,callouts,macros"] -------------------------------------------------- include-tagged::{doc-tests-file}[{api}-response] --------------------------------------------------- +-------------------------------------------------- \ No newline at end of file diff --git a/x-pack/docs/en/rest-api/security/invalidate-tokens.asciidoc b/x-pack/docs/en/rest-api/security/invalidate-tokens.asciidoc index 540f5866825..18c88f7addd 100644 --- a/x-pack/docs/en/rest-api/security/invalidate-tokens.asciidoc +++ b/x-pack/docs/en/rest-api/security/invalidate-tokens.asciidoc @@ -2,7 +2,7 @@ [[security-api-invalidate-token]] === Invalidate token API -Invalidates an access token or a refresh token. +Invalidates one or more access tokens or refresh tokens. ==== Request @@ -19,21 +19,31 @@ can no longer be used. That time period is defined by the The refresh tokens returned by the <> are only valid for 24 hours. They can also be used exactly once. -If you want to invalidate an access or refresh token immediately, use this invalidate token API. +If you want to invalidate one or more access or refresh tokens immediately, use this invalidate token API. ==== Request Body The following parameters can be specified in the body of a DELETE request and -pertain to invalidating a token: +pertain to invalidating tokens: `token` (optional):: -(string) An access token. This parameter cannot be used when `refresh_token` is used. +(string) An access token. This parameter cannot be used any of `refresh_token`, `realm_name` or + `username` are used. `refresh_token` (optional):: -(string) A refresh token. This parameter cannot be used when `token` is used. +(string) A refresh token. This parameter cannot be used any of `refresh_token`, `realm_name` or + `username` are used. -NOTE: One of `token` or `refresh_token` parameters is required. +`realm_name` (optional):: +(string) The name of an authentication realm. This parameter cannot be used with either `refresh_token` or `token`. + +`username` (optional):: +(string) The username of a user. This parameter cannot be used with either `refresh_token` or `token` + +NOTE: While all parameters are optional, at least one of them is required. More specifically, either one of `token` +or `refresh_token` parameters is required. If none of these two are specified, then `realm_name` and/or `username` +need to be specified. ==== Examples @@ -59,15 +69,75 @@ DELETE /_security/oauth2/token -------------------------------------------------- // NOTCONSOLE -A successful call returns a JSON structure that indicates whether the token -has already been invalidated. +The following example invalidates all access tokens and refresh tokens for the `saml1` realm immediately: [source,js] -------------------------------------------------- +DELETE /_xpack/security/oauth2/token { - "created" : true <1> + "realm_name" : "saml1" } -------------------------------------------------- // NOTCONSOLE -<1> When a token has already been invalidated, `created` is set to false. +The following example invalidates all access tokens and refresh tokens for the user `myuser` in all realms immediately: + +[source,js] +-------------------------------------------------- +DELETE /_xpack/security/oauth2/token +{ + "username" : "myuser" +} +-------------------------------------------------- +// NOTCONSOLE + +Finally, the following example invalidates all access tokens and refresh tokens for the user `myuser` in + the `saml1` realm immediately: + +[source,js] +-------------------------------------------------- +DELETE /_xpack/security/oauth2/token +{ + "username" : "myuser", + "realm_name" : "saml1" +} +-------------------------------------------------- +// NOTCONSOLE + +A successful call returns a JSON structure that contains the number of tokens that were invalidated, the number +of tokens that had already been invalidated, and potentially a list of errors encountered while invalidating +specific tokens. + +[source,js] +-------------------------------------------------- +{ + "invalidated_tokens":9, <1> + "previously_invalidated_tokens":15, <2> + "error_count":2, <3> + "error_details":[ <4> + { + "type":"exception", + "reason":"Elasticsearch exception [type=exception, reason=foo]", + "caused_by":{ + "type":"exception", + "reason":"Elasticsearch exception [type=illegal_argument_exception, reason=bar]" + } + }, + { + "type":"exception", + "reason":"Elasticsearch exception [type=exception, reason=boo]", + "caused_by":{ + "type":"exception", + "reason":"Elasticsearch exception [type=illegal_argument_exception, reason=far]" + } + } + ] +} +-------------------------------------------------- +// NOTCONSOLE + +<1> The number of the tokens that were invalidated as part of this request. +<2> The number of tokens that were already invalidated. +<3> The number of errors that were encountered when invalidating the tokens. +<4> Details about these errors. This field is not present in the response when + `error_count` is 0. diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenAction.java index 679ee0756f6..57bd5bd35dd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenAction.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.security.action.token; import org.elasticsearch.action.Action; /** - * Action for invalidating a given token + * Action for invalidating one or more tokens */ public final class InvalidateTokenAction extends Action { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequest.java index 7a8372fe456..de3b73ec4af 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequest.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequest.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.security.action.token; import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -22,31 +23,81 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; public final class InvalidateTokenRequest extends ActionRequest { public enum Type { - ACCESS_TOKEN, - REFRESH_TOKEN + ACCESS_TOKEN("token"), + REFRESH_TOKEN("refresh_token"); + + private final String value; + + Type(String value) { + this.value = value; + } + + public String getValue() { + return value; + } + + public static Type fromString(String tokenType) { + if (tokenType != null) { + for (Type type : values()) { + if (type.getValue().equals(tokenType)) { + return type; + } + } + } + return null; + } } private String tokenString; private Type tokenType; + private String realmName; + private String userName; public InvalidateTokenRequest() {} /** - * @param tokenString the string representation of the token + * @param tokenString the string representation of the token to be invalidated + * @param tokenType the type of the token to be invalidated + * @param realmName the name of the realm for which all tokens will be invalidated + * @param userName the principal of the user for which all tokens will be invalidated */ - public InvalidateTokenRequest(String tokenString, Type type) { + public InvalidateTokenRequest(@Nullable String tokenString, @Nullable String tokenType, + @Nullable String realmName, @Nullable String userName) { this.tokenString = tokenString; - this.tokenType = type; + this.tokenType = Type.fromString(tokenType); + this.realmName = realmName; + this.userName = userName; + } + + /** + * @param tokenString the string representation of the token to be invalidated + * @param tokenType the type of the token to be invalidated + */ + public InvalidateTokenRequest(String tokenString, String tokenType) { + this.tokenString = tokenString; + this.tokenType = Type.fromString(tokenType); + this.realmName = null; + this.userName = null; } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if (Strings.isNullOrEmpty(tokenString)) { - validationException = addValidationError("token string must be provided", null); - } - if (tokenType == null) { - validationException = addValidationError("token type must be provided", validationException); + if (Strings.hasText(realmName) || Strings.hasText(userName)) { + if (Strings.hasText(tokenString)) { + validationException = + addValidationError("token string must not be provided when realm name or username is specified", null); + } + if (tokenType != null) { + validationException = + addValidationError("token type must not be provided when realm name or username is specified", validationException); + } + } else if (Strings.isNullOrEmpty(tokenString)) { + validationException = + addValidationError("token string must be provided when not specifying a realm name or a username", null); + } else if (tokenType == null) { + validationException = + addValidationError("token type must be provided when a token string is specified", null); } return validationException; } @@ -67,26 +118,76 @@ public final class InvalidateTokenRequest extends ActionRequest { this.tokenType = tokenType; } + public String getRealmName() { + return realmName; + } + + public void setRealmName(String realmName) { + this.realmName = realmName; + } + + public String getUserName() { + return userName; + } + + public void setUserName(String userName) { + this.userName = userName; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeString(tokenString); + if (out.getVersion().before(Version.V_7_0_0)) { + if (Strings.isNullOrEmpty(tokenString)) { + throw new IllegalArgumentException("token is required for versions < v6.6.0"); + } + out.writeString(tokenString); + } else { + out.writeOptionalString(tokenString); + } if (out.getVersion().onOrAfter(Version.V_6_2_0)) { - out.writeVInt(tokenType.ordinal()); + if (out.getVersion().before(Version.V_7_0_0)) { + if (tokenType == null) { + throw new IllegalArgumentException("token type is not optional for versions > v6.2.0 and < v6.6.0"); + } + out.writeVInt(tokenType.ordinal()); + } else { + out.writeOptionalVInt(tokenType == null ? null : tokenType.ordinal()); + } } else if (tokenType == Type.REFRESH_TOKEN) { - throw new IllegalArgumentException("refresh token invalidation cannot be serialized with version [" + out.getVersion() + - "]"); + throw new IllegalArgumentException("refresh token invalidation cannot be serialized with version [" + out.getVersion() + "]"); + } + if (out.getVersion().onOrAfter(Version.V_7_0_0)) { + out.writeOptionalString(realmName); + out.writeOptionalString(userName); + } else if (realmName != null || userName != null) { + throw new IllegalArgumentException( + "realm or user token invalidation cannot be serialized with version [" + out.getVersion() + "]"); } } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); - tokenString = in.readString(); + if (in.getVersion().before(Version.V_7_0_0)) { + tokenString = in.readString(); + } else { + tokenString = in.readOptionalString(); + } if (in.getVersion().onOrAfter(Version.V_6_2_0)) { - tokenType = Type.values()[in.readVInt()]; + if (in.getVersion().before(Version.V_7_0_0)) { + int type = in.readVInt(); + tokenType = Type.values()[type]; + } else { + Integer type = in.readOptionalVInt(); + tokenType = type == null ? null : Type.values()[type]; + } } else { tokenType = Type.ACCESS_TOKEN; } + if (in.getVersion().onOrAfter(Version.V_7_0_0)) { + realmName = in.readOptionalString(); + userName = in.readOptionalString(); + } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestBuilder.java index f77f6c65332..0b454905cfa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestBuilder.java @@ -34,4 +34,20 @@ public final class InvalidateTokenRequestBuilder request.setTokenType(type); return this; } + + /** + * Sets the name of the realm for which all tokens should be invalidated + */ + public InvalidateTokenRequestBuilder setRealmName(String realmName) { + request.setRealmName(realmName); + return this; + } + + /** + * Sets the username for which all tokens should be invalidated + */ + public InvalidateTokenRequestBuilder setUserName(String username) { + request.setUserName(username); + return this; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponse.java index cebb005b272..886caeac370 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponse.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponse.java @@ -5,41 +5,83 @@ */ package org.elasticsearch.xpack.core.security.action.token; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Objects; /** - * Response for a invalidation of a token. + * Response for a invalidation of one or multiple tokens. */ -public final class InvalidateTokenResponse extends ActionResponse { +public final class InvalidateTokenResponse extends ActionResponse implements ToXContent { - private boolean created; + private TokensInvalidationResult result; public InvalidateTokenResponse() {} - public InvalidateTokenResponse(boolean created) { - this.created = created; + public InvalidateTokenResponse(TokensInvalidationResult result) { + this.result = result; } - /** - * If the token is already invalidated then created will be false - */ - public boolean isCreated() { - return created; + public TokensInvalidationResult getResult() { + return result; + } + + private boolean isCreated() { + return result.getInvalidatedTokens().size() > 0 + && result.getPreviouslyInvalidatedTokens().isEmpty() + && result.getErrors().isEmpty(); } @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); - out.writeBoolean(created); + if (out.getVersion().before(Version.V_7_0_0)) { + out.writeBoolean(isCreated()); + } else { + result.writeTo(out); + } } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); - created = in.readBoolean(); + if (in.getVersion().before(Version.V_7_0_0)) { + final boolean created = in.readBoolean(); + if (created) { + result = new TokensInvalidationResult(Arrays.asList(""), Collections.emptyList(), Collections.emptyList(), 0); + } else { + result = new TokensInvalidationResult(Collections.emptyList(), Arrays.asList(""), Collections.emptyList(), 0); + } + } else { + result = new TokensInvalidationResult(in); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + result.toXContent(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + InvalidateTokenResponse that = (InvalidateTokenResponse) o; + return Objects.equals(result, that.result); + } + + @Override + public int hashCode() { + return Objects.hash(result); } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java new file mode 100644 index 00000000000..cfa83b63ed5 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/support/TokensInvalidationResult.java @@ -0,0 +1,113 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.core.security.authc.support; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Nullable; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * The result of attempting to invalidate one or multiple tokens. The result contains information about: + *
    + *
  • how many of the tokens were actually invalidated
  • + *
  • how many tokens are not invalidated in this request because they were already invalidated
  • + *
  • how many errors were encountered while invalidating tokens and the error details
  • + *
+ */ +public class TokensInvalidationResult implements ToXContentObject, Writeable { + + private final List invalidatedTokens; + private final List previouslyInvalidatedTokens; + private final List errors; + private final int attemptCount; + + public TokensInvalidationResult(List invalidatedTokens, List previouslyInvalidatedTokens, + @Nullable List errors, int attemptCount) { + Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided"); + this.invalidatedTokens = invalidatedTokens; + Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided"); + this.previouslyInvalidatedTokens = previouslyInvalidatedTokens; + if (null != errors) { + this.errors = errors; + } else { + this.errors = Collections.emptyList(); + } + this.attemptCount = attemptCount; + } + + public TokensInvalidationResult(StreamInput in) throws IOException { + this.invalidatedTokens = in.readList(StreamInput::readString); + this.previouslyInvalidatedTokens = in.readList(StreamInput::readString); + this.errors = in.readList(StreamInput::readException); + this.attemptCount = in.readVInt(); + } + + public static TokensInvalidationResult emptyResult() { + return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0); + } + + + public List getInvalidatedTokens() { + return invalidatedTokens; + } + + public List getPreviouslyInvalidatedTokens() { + return previouslyInvalidatedTokens; + } + + public List getErrors() { + return errors; + } + + public int getAttemptCount() { + return attemptCount; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject() + //Remove created after PR is backported to 6.x + .field("created", isCreated()) + .field("invalidated_tokens", invalidatedTokens.size()) + .field("previously_invalidated_tokens", previouslyInvalidatedTokens.size()) + .field("error_count", errors.size()); + if (errors.isEmpty() == false) { + builder.field("error_details"); + builder.startArray(); + for (ElasticsearchException e : errors) { + builder.startObject(); + ElasticsearchException.generateThrowableXContent(builder, params, e); + builder.endObject(); + } + builder.endArray(); + } + return builder.endObject(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringList(invalidatedTokens); + out.writeStringList(previouslyInvalidatedTokens); + out.writeCollection(errors, StreamOutput::writeException); + out.writeVInt(attemptCount); + } + + private boolean isCreated() { + return this.getInvalidatedTokens().size() > 0 + && this.getPreviouslyInvalidatedTokens().isEmpty() + && this.getErrors().isEmpty(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/client/SecurityClient.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/client/SecurityClient.java index ef59f870c68..a7faf4d2231 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/client/SecurityClient.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/client/SecurityClient.java @@ -326,6 +326,10 @@ public class SecurityClient { return new InvalidateTokenRequestBuilder(client).setTokenString(token); } + public InvalidateTokenRequestBuilder prepareInvalidateToken() { + return new InvalidateTokenRequestBuilder(client); + } + public void invalidateToken(InvalidateTokenRequest request, ActionListener listener) { client.execute(InvalidateTokenAction.INSTANCE, request, listener); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/CreateTokenRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/CreateTokenRequestTests.java index bd23198e8ea..2d8782f0111 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/CreateTokenRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/CreateTokenRequestTests.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.security.action.token; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasItem; diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestTests.java new file mode 100644 index 00000000000..3fd7eb7da46 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenRequestTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.security.action.token; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.containsString; + +public class InvalidateTokenRequestTests extends ESTestCase { + + public void testValidation() { + InvalidateTokenRequest request = new InvalidateTokenRequest(); + ActionRequestValidationException ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), containsString("token string must be provided when not specifying a realm")); + + request = new InvalidateTokenRequest(randomAlphaOfLength(12), randomFrom("", null)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), containsString("token type must be provided when a token string is specified")); + + request = new InvalidateTokenRequest(randomFrom("", null), "access_token"); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), containsString("token string must be provided when not specifying a realm")); + + request = new InvalidateTokenRequest(randomFrom("", null), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8)); + ve = request.validate(); + assertNull(ve); + + request = + new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), + containsString("token string must not be provided when realm name or username is specified")); + + request = new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("token", "refresh_token"), + randomAlphaOfLength(4), randomAlphaOfLength(8)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(2, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), + containsString("token string must not be provided when realm name or username is specified")); + assertThat(ve.validationErrors().get(1), + containsString("token type must not be provided when realm name or username is specified")); + + request = + new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), + containsString("token string must not be provided when realm name or username is specified")); + + request = + new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("token", "refresh_token"), randomFrom("", null), + randomAlphaOfLength(8)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(2, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), + containsString("token string must not be provided when realm name or username is specified")); + assertThat(ve.validationErrors().get(1), + containsString("token type must not be provided when realm name or username is specified")); + + request = new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomFrom("", null), randomAlphaOfLength(8)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), + containsString("token string must not be provided when realm name or username is specified")); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java new file mode 100644 index 00000000000..1a59971ff9c --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/action/token/InvalidateTokenResponseTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.core.security.action.token; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.Version; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.VersionUtils; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; + +public class InvalidateTokenResponseTests extends ESTestCase { + + public void testSerialization() throws IOException { + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), + Arrays.asList(generateRandomStringArray(20, 15, false)), + Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), + randomIntBetween(0, 5)); + InvalidateTokenResponse response = new InvalidateTokenResponse(result); + try (BytesStreamOutput output = new BytesStreamOutput()) { + response.writeTo(output); + try (StreamInput input = output.bytes().streamInput()) { + InvalidateTokenResponse serialized = new InvalidateTokenResponse(); + serialized.readFrom(input); + assertThat(serialized.getResult().getInvalidatedTokens(), equalTo(response.getResult().getInvalidatedTokens())); + assertThat(serialized.getResult().getPreviouslyInvalidatedTokens(), + equalTo(response.getResult().getPreviouslyInvalidatedTokens())); + assertThat(serialized.getResult().getErrors().size(), equalTo(response.getResult().getErrors().size())); + assertThat(serialized.getResult().getErrors().get(0).toString(), containsString("this is an error message")); + assertThat(serialized.getResult().getErrors().get(1).toString(), containsString("this is an error message2")); + } + } + + result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)), + Arrays.asList(generateRandomStringArray(20, 15, false)), + Collections.emptyList(), randomIntBetween(0, 5)); + response = new InvalidateTokenResponse(result); + try (BytesStreamOutput output = new BytesStreamOutput()) { + response.writeTo(output); + try (StreamInput input = output.bytes().streamInput()) { + InvalidateTokenResponse serialized = new InvalidateTokenResponse(); + serialized.readFrom(input); + assertThat(serialized.getResult().getInvalidatedTokens(), equalTo(response.getResult().getInvalidatedTokens())); + assertThat(serialized.getResult().getPreviouslyInvalidatedTokens(), + equalTo(response.getResult().getPreviouslyInvalidatedTokens())); + assertThat(serialized.getResult().getErrors().size(), equalTo(response.getResult().getErrors().size())); + } + } + } + + public void testSerializationToPre66Version() throws IOException{ + final Version version = VersionUtils.randomVersionBetween(random(), Version.V_6_2_0, Version.V_6_5_1); + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)), + Arrays.asList(generateRandomStringArray(20, 15, false, false)), + Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), + randomIntBetween(0, 5)); + InvalidateTokenResponse response = new InvalidateTokenResponse(result); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + response.writeTo(output); + try (StreamInput input = output.bytes().streamInput()) { + // False as we have errors and previously invalidated tokens + assertThat(input.readBoolean(), equalTo(false)); + } + } + + result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)), + Arrays.asList(generateRandomStringArray(20, 15, false, false)), + Collections.emptyList(), randomIntBetween(0, 5)); + response = new InvalidateTokenResponse(result); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + response.writeTo(output); + try (StreamInput input = output.bytes().streamInput()) { + // False as we have previously invalidated tokens + assertThat(input.readBoolean(), equalTo(false)); + } + } + + result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)), + Collections.emptyList(), Collections.emptyList(), randomIntBetween(0, 5)); + response = new InvalidateTokenResponse(result); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + response.writeTo(output); + try (StreamInput input = output.bytes().streamInput()) { + assertThat(input.readBoolean(), equalTo(true)); + } + } + } + + public void testToXContent() throws IOException { + List invalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); + List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false)); + TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens, + Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")), + new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))), + randomIntBetween(0, 5)); + InvalidateTokenResponse response = new InvalidateTokenResponse(result); + XContentBuilder builder = XContentFactory.jsonBuilder(); + response.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertThat(Strings.toString(builder), + equalTo("{\"created\":false," + + "\"invalidated_tokens\":" + invalidatedTokens.size() + "," + + "\"previously_invalidated_tokens\":" + previouslyInvalidatedTokens.size() + "," + + "\"error_count\":2," + + "\"error_details\":[" + + "{\"type\":\"exception\"," + + "\"reason\":\"foo\"," + + "\"caused_by\":{" + + "\"type\":\"illegal_argument_exception\"," + + "\"reason\":\"this is an error message\"}" + + "}," + + "{\"type\":\"exception\"," + + "\"reason\":\"bar\"," + + "\"caused_by\":" + + "{\"type\":\"illegal_argument_exception\"," + + "\"reason\":\"this is an error message2\"}" + + "}" + + "]" + + "}")); + } +} diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java index f0e6bf2c990..8c35df01ed9 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionAction; import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionRequest; import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionResponse; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.TokenService; import org.elasticsearch.xpack.security.authc.UserToken; @@ -27,12 +28,11 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect; import org.elasticsearch.xpack.security.authc.saml.SamlUtils; import org.opensaml.saml.saml2.core.LogoutResponse; -import java.util.Collection; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.stream.Collectors; +import java.util.function.Predicate; import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.findSamlRealms; @@ -85,7 +85,7 @@ public final class TransportSamlInvalidateSessionAction private void findAndInvalidateTokens(SamlRealm realm, SamlLogoutRequestHandler.Result result, ActionListener listener) { final Map tokenMetadata = realm.createTokenMetadata(result.getNameId(), result.getSession()); - if (Strings.hasText((String) tokenMetadata.get(SamlRealm.TOKEN_METADATA_NAMEID_VALUE)) == false) { + if (Strings.isNullOrEmpty((String) tokenMetadata.get(SamlRealm.TOKEN_METADATA_NAMEID_VALUE))) { // If we don't have a valid name-id to match against, don't do anything logger.debug("Logout request [{}] has no NameID value, so cannot invalidate any sessions", result); listener.onResponse(0); @@ -93,22 +93,21 @@ public final class TransportSamlInvalidateSessionAction } tokenService.findActiveTokensForRealm(realm.name(), ActionListener.wrap(tokens -> { - List> sessionTokens = filterTokens(tokens, tokenMetadata); - logger.debug("Found [{}] token pairs to invalidate for SAML metadata [{}]", sessionTokens.size(), tokenMetadata); - if (sessionTokens.isEmpty()) { - listener.onResponse(0); - } else { - GroupedActionListener groupedListener = new GroupedActionListener<>( - ActionListener.wrap(collection -> listener.onResponse(collection.size()), listener::onFailure), - sessionTokens.size(), Collections.emptyList() - ); - sessionTokens.forEach(tuple -> invalidateTokenPair(tuple, groupedListener)); - } - }, e -> listener.onFailure(e) - )); + logger.debug("Found [{}] token pairs to invalidate for SAML metadata [{}]", tokens.size(), tokenMetadata); + if (tokens.isEmpty()) { + listener.onResponse(0); + } else { + GroupedActionListener groupedListener = new GroupedActionListener<>( + ActionListener.wrap(collection -> listener.onResponse(collection.size()), listener::onFailure), + tokens.size(), Collections.emptyList() + ); + tokens.forEach(tuple -> invalidateTokenPair(tuple, groupedListener)); + } + }, listener::onFailure + ), containsMetadata(tokenMetadata)); } - private void invalidateTokenPair(Tuple tokenPair, ActionListener listener) { + private void invalidateTokenPair(Tuple tokenPair, ActionListener listener) { // Invalidate the refresh token first, so the client doesn't trigger a refresh once the access token is invalidated tokenService.invalidateRefreshToken(tokenPair.v2(), ActionListener.wrap(ignore -> tokenService.invalidateAccessToken( tokenPair.v1(), @@ -118,13 +117,12 @@ public final class TransportSamlInvalidateSessionAction })), listener::onFailure)); } - private List> filterTokens(Collection> tokens, Map requiredMetadata) { - return tokens.stream() - .filter(tup -> { - Map actualMetadata = tup.v1().getMetadata(); - return requiredMetadata.entrySet().stream().allMatch(e -> Objects.equals(actualMetadata.get(e.getKey()), e.getValue())); - }) - .collect(Collectors.toList()); + + private Predicate> containsMetadata(Map requiredMetadata) { + return source -> { + Map actualMetadata = (Map) source.get("metadata"); + return requiredMetadata.entrySet().stream().allMatch(e -> Objects.equals(actualMetadata.get(e.getKey()), e.getValue())); + }; } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutAction.java index b62702ead78..28e9f911cd5 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutAction.java @@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutRequest; import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutResponse; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Realm; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.security.authc.Realms; import org.elasticsearch.xpack.security.authc.TokenService; @@ -79,7 +80,7 @@ public final class TransportSamlLogoutAction }, listener::onFailure)); } - private void invalidateRefreshToken(String refreshToken, ActionListener listener) { + private void invalidateRefreshToken(String refreshToken, ActionListener listener) { if (refreshToken == null) { listener.onResponse(null); } else { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportInvalidateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportInvalidateTokenAction.java index 70f614435fc..9f0443a86f7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportInvalidateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/token/TransportInvalidateTokenAction.java @@ -8,12 +8,14 @@ package org.elasticsearch.xpack.security.action.token; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenAction; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.security.authc.TokenService; /** @@ -31,9 +33,12 @@ public final class TransportInvalidateTokenAction extends HandledTransportAction @Override protected void doExecute(Task task, InvalidateTokenRequest request, ActionListener listener) { - final ActionListener invalidateListener = - ActionListener.wrap(created -> listener.onResponse(new InvalidateTokenResponse(created)), listener::onFailure); - if (request.getTokenType() == InvalidateTokenRequest.Type.ACCESS_TOKEN) { + final ActionListener invalidateListener = + ActionListener.wrap(tokensInvalidationResult -> + listener.onResponse(new InvalidateTokenResponse(tokensInvalidationResult)), listener::onFailure); + if (Strings.hasText(request.getUserName()) || Strings.hasText(request.getRealmName())) { + tokenService.invalidateActiveTokensForRealmAndUser(request.getRealmName(), request.getUserName(), invalidateListener); + } else if (request.getTokenType() == InvalidateTokenRequest.Type.ACCESS_TOKEN) { tokenService.invalidateAccessToken(request.getTokenString(), invalidateListener); } else { assert request.getTokenType() == InvalidateTokenRequest.Type.REFRESH_TOKEN; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index be5b11aa666..15d3e758426 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -17,6 +17,11 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest.OpType; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; import org.elasticsearch.action.get.MultiGetItemResponse; @@ -24,7 +29,6 @@ import org.elasticsearch.action.get.MultiGetRequest; import org.elasticsearch.action.get.MultiGetResponse; import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; -import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; @@ -39,6 +43,7 @@ import org.elasticsearch.cluster.ack.AckedRequest; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.cluster.metadata.MetaData; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Nullable; import org.elasticsearch.common.Priority; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; @@ -61,7 +66,6 @@ import org.elasticsearch.common.util.iterable.Iterables; import org.elasticsearch.common.xcontent.XContentBuilder; import org.elasticsearch.common.xcontent.XContentFactory; import org.elasticsearch.core.internal.io.IOUtils; -import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.index.engine.VersionConflictEngineException; import org.elasticsearch.index.query.BoolQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; @@ -74,6 +78,7 @@ import org.elasticsearch.xpack.core.security.ScrollHelper; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp; import org.elasticsearch.xpack.core.security.authc.TokenMetaData; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.security.support.SecurityIndexManager; import javax.crypto.Cipher; @@ -90,6 +95,7 @@ import java.io.ByteArrayOutputStream; import java.io.Closeable; import java.io.IOException; import java.io.OutputStream; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.GeneralSecurityException; @@ -116,6 +122,8 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException; import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK; @@ -221,9 +229,9 @@ public final class TokenService { boolean includeRefreshToken) throws IOException { ensureEnabled(); if (authentication == null) { - listener.onFailure(traceLog("create token", null, new IllegalArgumentException("authentication must be provided"))); + listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided"))); } else if (originatingClientAuth == null) { - listener.onFailure(traceLog("create token", null, + listener.onFailure(traceLog("create token", new IllegalArgumentException("originating client authentication must be provided"))); } else { final Instant created = clock.instant(); @@ -471,7 +479,7 @@ public final class TokenService { * have been created on versions on or after 6.2; this step involves performing an update to * the token document and setting the invalidated field to true */ - public void invalidateAccessToken(String tokenString, ActionListener listener) { + public void invalidateAccessToken(String tokenString, ActionListener listener) { ensureEnabled(); if (Strings.isNullOrEmpty(tokenString)) { logger.trace("No token-string provided"); @@ -484,7 +492,8 @@ public final class TokenService { listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException())); } else { final long expirationEpochMilli = getExpirationTime().toEpochMilli(); - indexBwcInvalidation(userToken, listener, new AtomicInteger(0), expirationEpochMilli); + indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), + expirationEpochMilli, null); } }, listener::onFailure)); } catch (IOException e) { @@ -499,7 +508,7 @@ public final class TokenService { * * @see #invalidateAccessToken(String, ActionListener) */ - public void invalidateAccessToken(UserToken userToken, ActionListener listener) { + public void invalidateAccessToken(UserToken userToken, ActionListener listener) { ensureEnabled(); if (userToken == null) { logger.trace("No access token provided"); @@ -507,11 +516,17 @@ public final class TokenService { } else { maybeStartTokenRemover(); final long expirationEpochMilli = getExpirationTime().toEpochMilli(); - indexBwcInvalidation(userToken, listener, new AtomicInteger(0), expirationEpochMilli); + indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), expirationEpochMilli, null); } } - public void invalidateRefreshToken(String refreshToken, ActionListener listener) { + /** + * This method performs the steps necessary to invalidate a refresh token so that it may no longer be used. + * + * @param refreshToken The string representation of the refresh token + * @param listener the listener to notify upon completion + */ + public void invalidateRefreshToken(String refreshToken, ActionListener listener) { ensureEnabled(); if (Strings.isNullOrEmpty(refreshToken)) { logger.trace("No refresh token provided"); @@ -520,152 +535,222 @@ public final class TokenService { maybeStartTokenRemover(); findTokenFromRefreshToken(refreshToken, ActionListener.wrap(tuple -> { - final String docId = tuple.v1().getHits().getAt(0).getId(); - final long docVersion = tuple.v1().getHits().getAt(0).getVersion(); - indexInvalidation(docId, Version.CURRENT, listener, tuple.v2(), "refresh_token", docVersion); + final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId()); + indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null); }, listener::onFailure), new AtomicInteger(0)); } } /** - * Performs the actual bwc invalidation of a token and then kicks off the new invalidation method + * Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given + * {@code username} so that they may no longer be used * - * @param userToken the token to invalidate - * @param listener the listener to notify upon completion - * @param attemptCount the number of attempts to invalidate that have already been tried - * @param expirationEpochMilli the expiration time as milliseconds since the epoch + * @param realmName the realm of which the tokens should be invalidated + * @param username the username for which the tokens should be invalidated + * @param listener the listener to notify upon completion */ - private void indexBwcInvalidation(UserToken userToken, ActionListener listener, AtomicInteger attemptCount, - long expirationEpochMilli) { - if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to invalidate token [{}] after [{}] attempts", userToken.getId(), attemptCount.get()); - listener.onFailure(invalidGrantException("failed to invalidate token")); + public void invalidateActiveTokensForRealmAndUser(@Nullable String realmName, @Nullable String username, + ActionListener listener) { + ensureEnabled(); + if (Strings.isNullOrEmpty(realmName) && Strings.isNullOrEmpty(username)) { + logger.trace("No realm name or username provided"); + listener.onFailure(new IllegalArgumentException("realm name or username must be provided")); } else { - final String invalidatedTokenId = getInvalidatedTokenDocumentId(userToken); - IndexRequest indexRequest = client.prepareIndex(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId) - .setOpType(OpType.CREATE) - .setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli) - .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) - .request(); - final String tokenDocId = getTokenDocumentId(userToken); - final Version version = userToken.getVersion(); - securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", tokenDocId, ex)), - () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, indexRequest, - ActionListener.wrap(indexResponse -> { - ActionListener wrappedListener = - ActionListener.wrap(ignore -> listener.onResponse(true), listener::onFailure); - indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L); - }, e -> { - Throwable cause = ExceptionsHelper.unwrapCause(e); - traceLog("(bwc) invalidate token", tokenDocId, cause); - if (cause instanceof VersionConflictEngineException) { - // expected since something else could have invalidated - ActionListener wrappedListener = - ActionListener.wrap(ignore -> listener.onResponse(false), listener::onFailure); - indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L); - } else if (isShardNotAvailableException(e)) { - attemptCount.incrementAndGet(); - indexBwcInvalidation(userToken, listener, attemptCount, expirationEpochMilli); - } else { - listener.onFailure(e); - } - }), client::index)); + if (Strings.isNullOrEmpty(realmName)) { + findActiveTokensForUser(username, ActionListener.wrap(tokenTuples -> { + if (tokenTuples.isEmpty()) { + logger.warn("No tokens to invalidate for realm [{}] and username [{}]", realmName, username); + listener.onResponse(TokensInvalidationResult.emptyResult()); + } else { + invalidateAllTokens(tokenTuples.stream().map(t -> t.v1().getId()).collect(Collectors.toList()), listener); + } + }, listener::onFailure)); + } else { + Predicate filter = null; + if (Strings.hasText(username)) { + filter = isOfUser(username); + } + findActiveTokensForRealm(realmName, ActionListener.wrap(tokenTuples -> { + if (tokenTuples.isEmpty()) { + logger.warn("No tokens to invalidate for realm [{}] and username [{}]", realmName, username); + listener.onResponse(TokensInvalidationResult.emptyResult()); + } else { + invalidateAllTokens(tokenTuples.stream().map(t -> t.v1().getId()).collect(Collectors.toList()), listener); + } + }, listener::onFailure), filter); + } } } /** - * Performs the actual invalidation of a token + * Invalidates a collection of access_token and refresh_token that were retrieved by + * {@link TokenService#invalidateActiveTokensForRealmAndUser} * - * @param tokenDocId the id of the token doc to invalidate + * @param accessTokenIds The ids of the access tokens which should be invalidated (along with the respective refresh_token) + * @param listener the listener to notify upon completion + */ + private void invalidateAllTokens(Collection accessTokenIds, ActionListener listener) { + maybeStartTokenRemover(); + final long expirationEpochMilli = getExpirationTime().toEpochMilli(); + // Invalidate the refresh tokens first so that they cannot be used to get new + // access tokens while we invalidate the access tokens we currently know about + indexInvalidation(accessTokenIds, ActionListener.wrap(result -> + indexBwcInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()), + expirationEpochMilli, result), + listener::onFailure), new AtomicInteger(0), "refresh_token", null); + } + + /** + * Performs the actual bwc invalidation of a collection of tokens and then kicks off the new invalidation method. + * + * @param tokenIds the collection of token ids or token document ids that should be invalidated + * @param listener the listener to notify upon completion + * @param attemptCount the number of attempts to invalidate that have already been tried + * @param expirationEpochMilli the expiration time as milliseconds since the epoch + * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating + * tokens up to the point of the retry. This result is added to the result of the current attempt + */ + private void indexBwcInvalidation(Collection tokenIds, ActionListener listener, + AtomicInteger attemptCount, long expirationEpochMilli, + @Nullable TokensInvalidationResult previousResult) { + + if (tokenIds.isEmpty()) { + logger.warn("No tokens provided for invalidation"); + listener.onFailure(invalidGrantException("No tokens provided for invalidation")); + } else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { + logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(), + attemptCount.get()); + listener.onFailure(invalidGrantException("failed to invalidate tokens")); + } else { + BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); + for (String tokenId : tokenIds) { + final String invalidatedTokenId = getInvalidatedTokenDocumentId(tokenId); + IndexRequest indexRequest = client.prepareIndex(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId) + .setOpType(OpType.CREATE) + .setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli) + .request(); + bulkRequestBuilder.add(indexRequest); + } + bulkRequestBuilder.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); + final BulkRequest bulkRequest = bulkRequestBuilder.request(); + securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", ex)), + () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequest, + ActionListener.wrap(bulkResponse -> { + List retryTokenIds = new ArrayList<>(); + for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { + if (bulkItemResponse.isFailed()) { + Throwable cause = bulkItemResponse.getFailure().getCause(); + logger.error(cause.getMessage()); + traceLog("(bwc) invalidate tokens", cause); + if (isShardNotAvailableException(cause)) { + retryTokenIds.add(getTokenIdFromInvalidatedTokenDocumentId(bulkItemResponse.getFailure().getId())); + } else if ((cause instanceof VersionConflictEngineException) == false){ + // We don't handle VersionConflictEngineException, the ticket has been invalidated + listener.onFailure(bulkItemResponse.getFailure().getCause()); + } + } + } + if (retryTokenIds.isEmpty() == false) { + attemptCount.incrementAndGet(); + indexBwcInvalidation(retryTokenIds, listener, attemptCount, expirationEpochMilli, previousResult); + } + indexInvalidation(tokenIds, listener, attemptCount, "access_token", previousResult); + }, e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + traceLog("(bwc) invalidate tokens", cause); + if (isShardNotAvailableException(cause)) { + attemptCount.incrementAndGet(); + indexBwcInvalidation(tokenIds, listener, attemptCount, expirationEpochMilli, previousResult); + } else { + listener.onFailure(e); + } + }), + client::bulk)); + } + } + + /** + * Performs the actual invalidation of a collection of tokens + * + * @param tokenIds the tokens to invalidate * @param listener the listener to notify upon completion * @param attemptCount the number of attempts to invalidate that have already been tried - * @param srcPrefix the prefix to use when constructing the doc to update - * @param documentVersion the expected version of the document we will update + * @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on + * what type of tokens should be invalidated + * @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating + * tokens up to the point of the retry. This result is added to the result of the current attempt */ - private void indexInvalidation(String tokenDocId, Version version, ActionListener listener, AtomicInteger attemptCount, - String srcPrefix, long documentVersion) { - if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { - logger.warn("Failed to invalidate token [{}] after [{}] attempts", tokenDocId, attemptCount.get()); - listener.onFailure(invalidGrantException("failed to invalidate token")); + private void indexInvalidation(Collection tokenIds, ActionListener listener, + AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) { + if (tokenIds.isEmpty()) { + logger.warn("No [{}] tokens provided for invalidation", srcPrefix); + listener.onFailure(invalidGrantException("No tokens provided for invalidation")); + } else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) { + logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(), + attemptCount.get()); + listener.onFailure(invalidGrantException("failed to invalidate tokens")); } else { - UpdateRequest request = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId) + BulkRequestBuilder bulkRequestBuilder = client.prepareBulk(); + for (String tokenId : tokenIds) { + UpdateRequest request = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, getTokenDocumentId(tokenId)) .setDoc(srcPrefix, Collections.singletonMap("invalidated", true)) - .setVersion(documentVersion) - .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + .setFetchSource(srcPrefix, null) .request(); - securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", tokenDocId, ex)), - () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, - ActionListener.wrap(updateResponse -> { - logger.debug("Invalidated [{}] for doc [{}]", srcPrefix, tokenDocId); - if (updateResponse.getGetResult() != null - && updateResponse.getGetResult().sourceAsMap().containsKey(srcPrefix) - && ((Map) updateResponse.getGetResult().sourceAsMap().get(srcPrefix)) - .containsKey("invalidated")) { - final boolean prevInvalidated = (boolean) - ((Map) updateResponse.getGetResult().sourceAsMap().get(srcPrefix)) - .get("invalidated"); - listener.onResponse(prevInvalidated == false); - } else { - listener.onResponse(true); + bulkRequestBuilder.add(request); + } + bulkRequestBuilder.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL); + securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", ex)), + () -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequestBuilder.request(), + ActionListener.wrap(bulkResponse -> { + ArrayList retryTokenDocIds = new ArrayList<>(); + ArrayList failedRequestResponses = new ArrayList<>(); + ArrayList previouslyInvalidated = new ArrayList<>(); + ArrayList invalidated = new ArrayList<>(); + if (null != previousResult) { + failedRequestResponses.addAll((previousResult.getErrors())); + previouslyInvalidated.addAll(previousResult.getPreviouslyInvalidatedTokens()); + invalidated.addAll(previousResult.getInvalidatedTokens()); } + for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) { + if (bulkItemResponse.isFailed()) { + Throwable cause = bulkItemResponse.getFailure().getCause(); + final String failedTokenDocId = getTokenIdFromDocumentId(bulkItemResponse.getFailure().getId()); + if (isShardNotAvailableException(cause)) { + retryTokenDocIds.add(failedTokenDocId); + } + else { + traceLog("invalidate access token", failedTokenDocId, cause); + failedRequestResponses.add(new ElasticsearchException("Error invalidating " + srcPrefix + ": ", cause)); + } + } else { + UpdateResponse updateResponse = bulkItemResponse.getResponse(); + if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) { + logger.debug("Invalidated [{}] for doc [{}]", srcPrefix, updateResponse.getGetResult().getId()); + invalidated.add(updateResponse.getGetResult().getId()); + } else if (updateResponse.getResult() == DocWriteResponse.Result.NOOP) { + previouslyInvalidated.add(updateResponse.getGetResult().getId()); + } + } + } + if (retryTokenDocIds.isEmpty() == false) { + TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses, attemptCount.get()); + attemptCount.incrementAndGet(); + indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult); + } + TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated, + failedRequestResponses, attemptCount.get()); + listener.onResponse(result); }, e -> { Throwable cause = ExceptionsHelper.unwrapCause(e); - traceLog("invalidate token", tokenDocId, cause); - if (cause instanceof DocumentMissingException) { - if (version.onOrAfter(Version.V_6_2_0)) { - // the document should always be there! - listener.onFailure(e); - } else { - listener.onResponse(false); - } - } else if (cause instanceof VersionConflictEngineException - || isShardNotAvailableException(cause)) { + traceLog("invalidate tokens", cause); + if (isShardNotAvailableException(cause)) { attemptCount.incrementAndGet(); - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, - client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(), - ActionListener.wrap(getResult -> { - if (getResult.isExists()) { - Map source = getResult.getSource(); - Map accessTokenSource = (Map) source.get("access_token"); - Consumer onFailure = ex -> listener.onFailure(traceLog("get token", tokenDocId, ex)); - if (accessTokenSource == null) { - onFailure.accept(new IllegalArgumentException( - "token document is missing access_token field")); - } else { - Boolean invalidated = (Boolean) accessTokenSource.get("invalidated"); - if (invalidated == null) { - onFailure.accept(new IllegalStateException( - "token document missing invalidated value")); - } else if (invalidated) { - logger.trace("Token [{}] is already invalidated", tokenDocId); - listener.onResponse(false); - } else { - indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix, - getResult.getVersion()); - } - } - } else if (version.onOrAfter(Version.V_6_2_0)) { - logger.warn("could not find token document [{}] but there should " + - "be one as token has version [{}]", tokenDocId, version); - listener.onFailure(invalidGrantException("could not invalidate the token")); - } else { - listener.onResponse(false); - } - }, - e1 -> { - traceLog("get token", tokenDocId, e1); - if (isShardNotAvailableException(e1)) { - // don't increment count; call again - indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix, - documentVersion); - } else { - listener.onFailure(e1); - } - }), client::get); + indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, previousResult); } else { listener.onFailure(e); } - }), client::update)); + }), client::bulk)); } } @@ -676,12 +761,12 @@ public final class TokenService { public void refreshToken(String refreshToken, ActionListener> listener) { ensureEnabled(); findTokenFromRefreshToken(refreshToken, - ActionListener.wrap(tuple -> { - final Authentication userAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); - final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); - innerRefresh(tokenDocId, userAuth, listener, tuple.v2()); - }, listener::onFailure), - new AtomicInteger(0)); + ActionListener.wrap(tuple -> { + final Authentication userAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); + final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); + innerRefresh(tokenDocId, userAuth, listener, tuple.v2()); + }, listener::onFailure), + new AtomicInteger(0)); } private void findTokenFromRefreshToken(String refreshToken, ActionListener> listener, @@ -691,11 +776,11 @@ public final class TokenService { listener.onFailure(invalidGrantException("could not refresh the requested token")); } else { SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) - .setQuery(QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("doc_type", "token")) - .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) - .setVersion(true) - .request(); + .setQuery(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", "token")) + .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) + .setVersion(true) + .request(); final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); if (frozenSecurityIndex.indexExists() == false) { @@ -860,12 +945,16 @@ public final class TokenService { } /** - * Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against + * Find stored refresh and access tokens that have not been invalidated or expired, and were issued against * the specified realm. + * + * @param realmName The name of the realm for which to get the tokens + * @param listener The listener to notify upon completion + * @param filter an optional Predicate to test the source of the found documents against */ - public void findActiveTokensForRealm(String realmName, ActionListener>> listener) { + public void findActiveTokensForRealm(String realmName, ActionListener>> listener, + @Nullable Predicate> filter) { ensureEnabled(); - final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); if (Strings.isNullOrEmpty(realmName)) { listener.onFailure(new IllegalArgumentException("Realm name is required")); @@ -883,7 +972,10 @@ public final class TokenService { .must(QueryBuilders.termQuery("access_token.invalidated", false)) .must(QueryBuilders.rangeQuery("access_token.user_token.expiration_time").gte(now.toEpochMilli())) ) - .should(QueryBuilders.termQuery("refresh_token.invalidated", false)) + .should(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("refresh_token.invalidated", false)) + .must(QueryBuilders.rangeQuery("creation_time").gte(now.toEpochMilli() - TimeValue.timeValueHours(24).millis())) + ) ); final SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) @@ -893,33 +985,102 @@ public final class TokenService { .setSize(1000) .setFetchSource(true) .request(); - securityIndex.checkIndexVersionThenExecute(listener::onFailure, - () -> ScrollHelper.fetchAllByEntity(client, request, listener, this::parseHit)); + () -> ScrollHelper.fetchAllByEntity(client, request, listener, (SearchHit hit) -> filterAndParseHit(hit, filter))); } } - private Tuple parseHit(SearchHit hit) { + /** + * Find stored refresh and access tokens that have not been invalidated or expired, and were issued for + * the specified user. + * + * @param username The user for which to get the tokens + * @param listener The listener to notify upon completion + */ + public void findActiveTokensForUser(String username, ActionListener>> listener) { + ensureEnabled(); + + final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze(); + if (Strings.isNullOrEmpty(username)) { + listener.onFailure(new IllegalArgumentException("username is required")); + } else if (frozenSecurityIndex.indexExists() == false) { + listener.onResponse(Collections.emptyList()); + } else if (frozenSecurityIndex.isAvailable() == false) { + listener.onFailure(frozenSecurityIndex.getUnavailableReason()); + } else { + final Instant now = clock.instant(); + final BoolQueryBuilder boolQuery = QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", "token")) + .filter(QueryBuilders.boolQuery() + .should(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("access_token.invalidated", false)) + .must(QueryBuilders.rangeQuery("access_token.user_token.expiration_time").gte(now.toEpochMilli())) + ) + .should(QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("refresh_token.invalidated", false)) + .must(QueryBuilders.rangeQuery("creation_time").gte(now.toEpochMilli() - TimeValue.timeValueHours(24).millis())) + ) + ); + + final SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) + .setScroll(DEFAULT_KEEPALIVE_SETTING.get(settings)) + .setQuery(boolQuery) + .setVersion(false) + .setSize(1000) + .setFetchSource(true) + .request(); + securityIndex.checkIndexVersionThenExecute(listener::onFailure, + () -> ScrollHelper.fetchAllByEntity(client, request, listener, + (SearchHit hit) -> filterAndParseHit(hit, isOfUser(username)))); + } + } + + private static Predicate> isOfUser(String username) { + return source -> { + String auth = (String) source.get("authentication"); + Integer version = (Integer) source.get("version"); + Version authVersion = Version.fromId(version); + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(auth))) { + in.setVersion(authVersion); + Authentication authentication = new Authentication(in); + return authentication.getUser().principal().equals(username); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + }; + } + + + private Tuple filterAndParseHit(SearchHit hit, @Nullable Predicate> filter) { final Map source = hit.getSourceAsMap(); if (source == null) { throw new IllegalStateException("token document did not have source but source should have been fetched"); } - try { - return parseTokensFromDocument(source); + return parseTokensFromDocument(source, filter); } catch (IOException e) { throw invalidGrantException("cannot read token from document"); } } /** - * @return A {@link Tuple} of access-token and refresh-token-id + * + * Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token + * + * @param source The token document source as retrieved + * @param filter an optional Predicate to test the source of the UserToken against + * @return A {@link Tuple} of access-token and refresh-token-id or null if a Predicate is defined and the userToken source doesn't + * satisfy it */ - private Tuple parseTokensFromDocument(Map source) throws IOException { - final String refreshToken = (String) ((Map) source.get("refresh_token")).get("token"); + private Tuple parseTokensFromDocument(Map source, @Nullable Predicate> filter) + throws IOException { + final String refreshToken = (String) ((Map) source.get("refresh_token")).get("token"); final Map userTokenSource = (Map) - ((Map) source.get("access_token")).get("user_token"); + ((Map) source.get("access_token")).get("user_token"); + if (null != filter && filter.test(userTokenSource) == false) { + return null; + } final String id = (String) userTokenSource.get("id"); final Integer version = (Integer) userTokenSource.get("version"); final String authString = (String) userTokenSource.get("authentication"); @@ -951,6 +1112,23 @@ public final class TokenService { return "token_" + id; } + private static String getTokenIdFromDocumentId(String docId) { + if (docId.startsWith("token_") == false) { + throw new IllegalStateException("TokenDocument ID [" + docId + "] has unexpected value"); + } else { + return docId.substring("token_".length()); + } + } + + private static String getTokenIdFromInvalidatedTokenDocumentId(String docId) { + final String invalidatedTokenDocPrefix = INVALIDATED_TOKEN_DOC_TYPE + "_"; + if (docId.startsWith(invalidatedTokenDocPrefix) == false) { + throw new IllegalStateException("InvalidatedTokenDocument ID [" + docId + "] has unexpected value"); + } else { + return docId.substring(invalidatedTokenDocPrefix.length()); + } + } + private void ensureEnabled() { if (enabled == false) { throw new IllegalStateException("tokens are not enabled"); @@ -1149,7 +1327,7 @@ public final class TokenService { } /** - * Creates an {@link ElasticsearchSecurityException} that indicates the token was expired. It + * Creates an {@link ElasticsearchSecurityException} that indicates the token was malformed. It * is up to the client to re-authenticate and obtain a new token. The format for this response * is defined in */ @@ -1171,7 +1349,7 @@ public final class TokenService { } /** - * Logs an exception at TRACE level (if enabled) + * Logs an exception concerning a specific Token at TRACE level (if enabled) */ private E traceLog(String action, String identifier, E exception) { if (logger.isTraceEnabled()) { @@ -1179,12 +1357,34 @@ public final class TokenService { final ElasticsearchException esEx = (ElasticsearchException) exception; final Object detail = esEx.getHeader("error_description"); if (detail != null) { - logger.trace("Failure in [{}] for id [{}] - [{}] [{}]", action, identifier, detail, esEx.getDetailedMessage()); + logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}] - [{}]", action, identifier, detail), + esEx); } else { - logger.trace("Failure in [{}] for id [{}] - [{}]", action, identifier, esEx.getDetailedMessage()); + logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}]", action, identifier), + esEx); } } else { - logger.trace("Failure in [{}] for id [{}] - [{}]", action, identifier, exception.toString()); + logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}]", action, identifier), exception); + } + } + return exception; + } + + /** + * Logs an exception at TRACE level (if enabled) + */ + private E traceLog(String action, E exception) { + if (logger.isTraceEnabled()) { + if (exception instanceof ElasticsearchException) { + final ElasticsearchException esEx = (ElasticsearchException) exception; + final Object detail = esEx.getHeader("error_description"); + if (detail != null) { + logger.trace(() -> new ParameterizedMessage("Failure in [{}] - [{}]", action, detail), esEx); + } else { + logger.trace(() -> new ParameterizedMessage("Failure in [{}]", action), esEx); + } + } else { + logger.trace(() -> new ParameterizedMessage("Failure in [{}]", action), exception); } } return exception; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java index 52228d2823a..9801f3c93c8 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java @@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.ParseField; import org.elasticsearch.common.Strings; -import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; @@ -37,11 +36,32 @@ import static org.elasticsearch.rest.RestRequest.Method.DELETE; public final class RestInvalidateTokenAction extends SecurityBaseRestHandler { private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(RestInvalidateTokenAction.class)); - static final ConstructingObjectParser, Void> PARSER = - new ConstructingObjectParser<>("invalidate_token", a -> new Tuple<>((String) a[0], (String) a[1])); + static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>("invalidate_token", a -> { + final String token = (String) a[0]; + final String refreshToken = (String) a[1]; + final String tokenString; + final String tokenType; + if (Strings.hasLength(token) && Strings.hasLength(refreshToken)) { + throw new IllegalArgumentException("only one of [token, refresh_token] may be sent per request"); + } else if (Strings.hasLength(token)) { + tokenString = token; + tokenType = InvalidateTokenRequest.Type.ACCESS_TOKEN.getValue(); + } else if (Strings.hasLength(refreshToken)) { + tokenString = refreshToken; + tokenType = InvalidateTokenRequest.Type.REFRESH_TOKEN.getValue(); + } else { + tokenString = null; + tokenType = null; + } + return new InvalidateTokenRequest(tokenString, tokenType, (String) a[2], (String) a[3]); + }); + static { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("token")); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refresh_token")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("realm_name")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("username")); } public RestInvalidateTokenAction(Settings settings, RestController controller, XPackLicenseState xPackLicenseState) { @@ -60,36 +80,16 @@ public final class RestInvalidateTokenAction extends SecurityBaseRestHandler { @Override protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client) throws IOException { try (XContentParser parser = request.contentParser()) { - final Tuple tuple = PARSER.parse(parser, null); - final String token = tuple.v1(); - final String refreshToken = tuple.v2(); - - final String tokenString; - final InvalidateTokenRequest.Type type; - if (Strings.hasLength(token) && Strings.hasLength(refreshToken)) { - throw new IllegalArgumentException("only one of [token, refresh_token] may be sent per request"); - } else if (Strings.hasLength(token)) { - tokenString = token; - type = InvalidateTokenRequest.Type.ACCESS_TOKEN; - } else if (Strings.hasLength(refreshToken)) { - tokenString = refreshToken; - type = InvalidateTokenRequest.Type.REFRESH_TOKEN; - } else { - tokenString = null; - type = null; - } - - final InvalidateTokenRequest tokenRequest = new InvalidateTokenRequest(tokenString, type); - return channel -> client.execute(InvalidateTokenAction.INSTANCE, tokenRequest, - new RestBuilderListener(channel) { - @Override - public RestResponse buildResponse(InvalidateTokenResponse invalidateResp, - XContentBuilder builder) throws Exception { - return new BytesRestResponse(RestStatus.OK, builder.startObject() - .field("created", invalidateResp.isCreated()) - .endObject()); - } - }); + final InvalidateTokenRequest invalidateTokenRequest = PARSER.parse(parser, null); + return channel -> client.execute(InvalidateTokenAction.INSTANCE, invalidateTokenRequest, + new RestBuilderListener(channel) { + @Override + public RestResponse buildResponse(InvalidateTokenResponse invalidateResp, + XContentBuilder builder) throws Exception { + invalidateResp.toXContent(builder, channel.request()); + return new BytesRestResponse(RestStatus.OK, builder); + } + }); } } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java index ba1d1762f06..5a4c8f3bde8 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlInvalidateSessionActionTests.java @@ -11,6 +11,10 @@ import org.elasticsearch.action.Action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionResponse; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.index.IndexAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; @@ -21,11 +25,11 @@ import org.elasticsearch.action.search.SearchAction; import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.search.SearchResponseSections; +import org.elasticsearch.action.search.SearchScrollAction; +import org.elasticsearch.action.search.SearchScrollRequest; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.update.UpdateAction; import org.elasticsearch.action.update.UpdateRequest; -import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.bytes.BytesReference; @@ -106,11 +110,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase { private SamlRealm samlRealm; private TokenService tokenService; private List indexRequests; - private List updateRequests; + private List bulkRequests; private List searchRequests; private TransportSamlInvalidateSessionAction action; private SamlLogoutRequestHandler.Result logoutRequest; private Function searchFunction = ignore -> new SearchHit[0]; + private Function searchScrollFunction = ignore -> new SearchHit[0]; @Before public void setup() throws Exception { @@ -132,8 +137,8 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase { new Authentication(new User("kibana"), new RealmRef("realm", "type", "node"), null).writeToContext(threadContext); indexRequests = new ArrayList<>(); - updateRequests = new ArrayList<>(); searchRequests = new ArrayList<>(); + bulkRequests = new ArrayList<>(); final Client client = new NoOpClient(threadPool) { @Override protected @@ -143,20 +148,29 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase { IndexRequest indexRequest = (IndexRequest) request; indexRequests.add(indexRequest); final IndexResponse response = new IndexResponse( - indexRequest.shardId(), indexRequest.type(), indexRequest.id(), 1, 1, 1, true); + indexRequest.shardId(), indexRequest.type(), indexRequest.id(), 1, 1, 1, true); + listener.onResponse((Response) response); + } else if (BulkAction.NAME.equals(action.name())) { + assertThat(request, instanceOf(BulkRequest.class)); + bulkRequests.add((BulkRequest) request); + final BulkResponse response = new BulkResponse(new BulkItemResponse[0], 1); listener.onResponse((Response) response); - } else if (UpdateAction.NAME.equals(action.name())) { - assertThat(request, instanceOf(UpdateRequest.class)); - updateRequests.add((UpdateRequest) request); - listener.onResponse((Response) new UpdateResponse()); } else if (SearchAction.NAME.equals(action.name())) { assertThat(request, instanceOf(SearchRequest.class)); SearchRequest searchRequest = (SearchRequest) request; searchRequests.add(searchRequest); final SearchHit[] hits = searchFunction.apply(searchRequest); final SearchResponse response = new SearchResponse( - new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f), - null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null); + new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f), + null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null); + listener.onResponse((Response) response); + } else if (SearchScrollAction.NAME.equals(action.name())){ + assertThat(request, instanceOf(SearchScrollRequest.class)); + SearchScrollRequest searchScrollRequest = (SearchScrollRequest) request; + final SearchHit[] hits = searchScrollFunction.apply(searchScrollRequest); + final SearchResponse response = new SearchResponse( + new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f), + null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null); listener.onResponse((Response) response); } else if (ClearScrollAction.NAME.equals(action.name())) { assertThat(request, instanceOf(ClearScrollRequest.class)); @@ -296,15 +310,33 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase { assertThat(((TermQueryBuilder) filter1.get(1)).fieldName(), equalTo("refresh_token.token")); assertThat(((TermQueryBuilder) filter1.get(1)).value(), equalTo(tokenToInvalidate1.v2())); - assertThat(updateRequests.size(), equalTo(4)); // (refresh-token + access-token) * 2 - assertThat(updateRequests.get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId())); - assertThat(updateRequests.get(1).id(), equalTo(updateRequests.get(0).id())); - assertThat(updateRequests.get(2).id(), equalTo("token_" + tokenToInvalidate2.v1().getId())); - assertThat(updateRequests.get(3).id(), equalTo(updateRequests.get(2).id())); - - assertThat(indexRequests.size(), equalTo(2)); // bwc-invalidate * 2 - assertThat(indexRequests.get(0).id(), startsWith("invalidated-token_")); - assertThat(indexRequests.get(1).id(), startsWith("invalidated-token_")); + assertThat(bulkRequests.size(), equalTo(6)); // 4 updates (refresh-token + access-token) plus 2 indexes (bwc-invalidate * 2) + // Invalidate refresh token 1 + assertThat(bulkRequests.get(0).requests().get(0), instanceOf(UpdateRequest.class)); + assertThat(bulkRequests.get(0).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId())); + UpdateRequest updateRequest1 = (UpdateRequest) bulkRequests.get(0).requests().get(0); + assertThat(updateRequest1.toString().contains("refresh_token"), equalTo(true)); + // BWC incalidate access token 1 + assertThat(bulkRequests.get(1).requests().get(0), instanceOf(IndexRequest.class)); + assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate1.v1().getId())); + // Invalidate access token 1 + assertThat(bulkRequests.get(2).requests().get(0), instanceOf(UpdateRequest.class)); + assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId())); + UpdateRequest updateRequest2 = (UpdateRequest) bulkRequests.get(2).requests().get(0); + assertThat(updateRequest2.toString().contains("access_token"), equalTo(true)); + // Invalidate refresh token 2 + assertThat(bulkRequests.get(3).requests().get(0), instanceOf(UpdateRequest.class)); + assertThat(bulkRequests.get(3).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId())); + UpdateRequest updateRequest3 = (UpdateRequest) bulkRequests.get(3).requests().get(0); + assertThat(updateRequest3.toString().contains("refresh_token"), equalTo(true)); + // BWC incalidate access token 2 + assertThat(bulkRequests.get(4).requests().get(0), instanceOf(IndexRequest.class)); + assertThat(bulkRequests.get(4).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate2.v1().getId())); + // Invalidate access token 2 + assertThat(bulkRequests.get(5).requests().get(0), instanceOf(UpdateRequest.class)); + assertThat(bulkRequests.get(5).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId())); + UpdateRequest updateRequest4 = (UpdateRequest) bulkRequests.get(5).requests().get(0); + assertThat(updateRequest4.toString().contains("access_token"), equalTo(true)); } private Function findTokenByRefreshToken(SearchHit[] searchHits) { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java index 66d3233b07a..7dec105e1ee 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/action/saml/TransportSamlLogoutActionTests.java @@ -6,7 +6,11 @@ package org.elasticsearch.xpack.security.action.saml; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.bulk.BulkAction; +import org.elasticsearch.action.bulk.BulkItemResponse; +import org.elasticsearch.action.bulk.BulkRequest; +import org.elasticsearch.action.bulk.BulkRequestBuilder; +import org.elasticsearch.action.bulk.BulkResponse; import org.elasticsearch.action.get.GetAction; import org.elasticsearch.action.get.GetRequestBuilder; import org.elasticsearch.action.get.GetResponse; @@ -24,7 +28,6 @@ import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.action.update.UpdateAction; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.action.update.UpdateRequestBuilder; -import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.collect.MapBuilder; @@ -72,6 +75,9 @@ import java.util.function.Consumer; import static org.elasticsearch.xpack.core.security.authc.RealmSettings.getFullSettingKey; import static org.elasticsearch.xpack.security.authc.TokenServiceTests.mockGetTokenFromId; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.startsWith; import static org.mockito.Matchers.any; @@ -89,7 +95,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { private SamlRealm samlRealm; private TokenService tokenService; private List indexRequests; - private List updateRequests; + private List bulkRequests; private TransportSamlLogoutAction action; private Client client; @@ -112,7 +118,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { new Authentication(new User("kibana"), new Authentication.RealmRef("realm", "type", "node"), null).writeToContext(threadContext); indexRequests = new ArrayList<>(); - updateRequests = new ArrayList<>(); + bulkRequests = new ArrayList<>(); client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); when(client.settings()).thenReturn(settings); @@ -137,6 +143,10 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { .setId((String) invocationOnMock.getArguments()[2]); return builder; }).when(client).prepareUpdate(anyString(), anyString(), anyString()); + doAnswer(invocationOnMock -> { + BulkRequestBuilder builder = new BulkRequestBuilder(client, BulkAction.INSTANCE); + return builder; + }).when(client).prepareBulk(); when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); doAnswer(invocationOnMock -> { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -154,15 +164,6 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { listener.onResponse(response); return Void.TYPE; }).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class)); - doAnswer(invocationOnMock -> { - UpdateRequest updateRequest = (UpdateRequest) invocationOnMock.getArguments()[0]; - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - updateRequests.add(updateRequest); - final UpdateResponse response = new UpdateResponse( - updateRequest.getShardId(), updateRequest.type(), updateRequest.id(), 1, DocWriteResponse.Result.UPDATED); - listener.onResponse(response); - return Void.TYPE; - }).when(client).update(any(UpdateRequest.class), any(ActionListener.class)); doAnswer(invocationOnMock -> { IndexRequest indexRequest = (IndexRequest) invocationOnMock.getArguments()[0]; ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; @@ -181,6 +182,14 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { listener.onResponse(response); return Void.TYPE; }).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class)); + doAnswer(invocationOnMock -> { + BulkRequest bulkRequest = (BulkRequest) invocationOnMock.getArguments()[0]; + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + bulkRequests.add(bulkRequest); + final BulkResponse response = new BulkResponse(new BulkItemResponse[0], 1); + listener.onResponse(response); + return Void.TYPE; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); final SecurityIndexManager securityIndex = mock(SecurityIndexManager.class); doAnswer(inv -> { @@ -247,9 +256,17 @@ public class TransportSamlLogoutActionTests extends SamlTestCase { assertThat(indexRequest1, notNullValue()); assertThat(indexRequest1.id(), startsWith("token")); - final IndexRequest indexRequest2 = indexRequests.get(1); - assertThat(indexRequest2, notNullValue()); - assertThat(indexRequest2.id(), startsWith("invalidated-token")); + assertThat(bulkRequests.size(), equalTo(2)); + final BulkRequest bulkRequest1 = bulkRequests.get(0); + assertThat(bulkRequest1.requests().size(), equalTo(1)); + assertThat(bulkRequest1.requests().get(0), instanceOf(IndexRequest.class)); + assertThat(bulkRequest1.requests().get(0).id(), startsWith("invalidated-token_")); + + final BulkRequest bulkRequest2 = bulkRequests.get(1); + assertThat(bulkRequest2.requests().size(), equalTo(1)); + assertThat(bulkRequest2.requests().get(0), instanceOf(UpdateRequest.class)); + assertThat(bulkRequest2.requests().get(0).id(), startsWith("token_")); + assertThat(bulkRequest2.requests().get(0).toString(), containsString("\"access_token\":{\"invalidated\":true")); } } diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index c4efdc16e10..968c17f556b 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -144,7 +144,9 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { .prepareInvalidateToken(response.getTokenString()) .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) .get(); - assertTrue(invalidateResponse.isCreated()); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1)); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); AtomicReference docId = new AtomicReference<>(); assertBusy(() -> { SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME) @@ -189,6 +191,72 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { }, 30, TimeUnit.SECONDS); } + public void testInvalidateAllTokensForUser() throws Exception{ + final int numOfRequests = randomIntBetween(5, 10); + for (int i = 0; i < numOfRequests; i++) { + securityClient().prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) + .get(); + } + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClientSuperuser = new SecurityClient(client); + InvalidateTokenResponse invalidateResponse = securityClientSuperuser + .prepareInvalidateToken() + .setUserName(SecuritySettingsSource.TEST_USER_NAME) + .get(); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(2 * (numOfRequests))); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); + } + + public void testInvalidateAllTokensForRealm() throws Exception{ + final int numOfRequests = randomIntBetween(5, 10); + for (int i = 0; i < numOfRequests; i++) { + securityClient().prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) + .get(); + } + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClientSuperuser = new SecurityClient(client); + InvalidateTokenResponse invalidateResponse = securityClientSuperuser + .prepareInvalidateToken() + .setRealmName("file") + .get(); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(2 * (numOfRequests))); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); + } + + public void testInvalidateAllTokensForRealmThatHasNone() { + final int numOfRequests = randomIntBetween(2, 4); + for (int i = 0; i < numOfRequests; i++) { + securityClient().prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray())) + .get(); + } + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER, + SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClientSuperuser = new SecurityClient(client); + InvalidateTokenResponse invalidateResponse = securityClientSuperuser + .prepareInvalidateToken() + .setRealmName("saml") + .get(); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); + } + public void testExpireMultipleTimes() { CreateTokenResponse response = securityClient().prepareCreateToken() .setGrantType("password") @@ -200,12 +268,16 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { .prepareInvalidateToken(response.getTokenString()) .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) .get(); - assertTrue(invalidateResponse.isCreated()); - assertFalse(securityClient() - .prepareInvalidateToken(response.getTokenString()) - .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) - .get() - .isCreated()); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1)); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); + InvalidateTokenResponse invalidateAgainResponse = securityClient() + .prepareInvalidateToken(response.getTokenString()) + .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) + .get(); + assertThat(invalidateAgainResponse.getResult().getInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateAgainResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(1)); + assertThat(invalidateAgainResponse.getResult().getErrors().size(), equalTo(0)); } public void testRefreshingToken() { @@ -248,7 +320,9 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { .prepareInvalidateToken(createTokenResponse.getRefreshToken()) .setType(InvalidateTokenRequest.Type.REFRESH_TOKEN) .get(); - assertTrue(invalidateResponse.isCreated()); + assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1)); + assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0)); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); @@ -362,9 +436,11 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { // invalidate PlainActionFuture invalidateResponseFuture = new PlainActionFuture<>(); InvalidateTokenRequest invalidateTokenRequest = - new InvalidateTokenRequest(createTokenResponse.getTokenString(), InvalidateTokenRequest.Type.ACCESS_TOKEN); + new InvalidateTokenRequest(createTokenResponse.getTokenString(), InvalidateTokenRequest.Type.ACCESS_TOKEN.getValue()); securityClient.invalidateToken(invalidateTokenRequest, invalidateResponseFuture); - assertTrue(invalidateResponseFuture.get().isCreated()); + assertThat(invalidateResponseFuture.get().getResult().getInvalidatedTokens().size(), equalTo(1)); + assertThat(invalidateResponseFuture.get().getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0)); + assertThat(invalidateResponseFuture.get().getResult().getErrors().size(), equalTo(0)); ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> { PlainActionFuture responseFuture = new PlainActionFuture<>(); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 7926b44a38c..286f07667ec 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -48,6 +48,7 @@ import org.elasticsearch.xpack.core.XPackSettings; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef; import org.elasticsearch.xpack.core.security.authc.TokenMetaData; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; import org.elasticsearch.xpack.core.security.user.User; import org.elasticsearch.xpack.core.watcher.watch.ClockMock; import org.elasticsearch.xpack.security.support.SecurityIndexManager; @@ -523,7 +524,7 @@ public class TokenServiceTests extends ESTestCase { assertNull(future.get()); e = expectThrows(IllegalStateException.class, () -> { - PlainActionFuture invalidateFuture = new PlainActionFuture<>(); + PlainActionFuture invalidateFuture = new PlainActionFuture<>(); tokenService.invalidateAccessToken((String) null, invalidateFuture); invalidateFuture.actionGet(); }); diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java new file mode 100644 index 00000000000..06c9411d0bc --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/support/TokensInvalidationResultTests.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ + +package org.elasticsearch.xpack.security.authc.support; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.ToXContent; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.json.JsonXContent; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult; + +import java.util.Arrays; +import java.util.Collections; + +import static org.hamcrest.Matchers.equalTo; + +public class TokensInvalidationResultTests extends ESTestCase { + + public void testToXcontent() throws Exception{ + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), + Arrays.asList("token3", "token4"), + Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")), + new ElasticsearchException("boo", new IllegalStateException("far"))), + randomIntBetween(0, 5)); + + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + result.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertThat(Strings.toString(builder), + equalTo( + "{\"created\":false," + + "\"invalidated_tokens\":2," + + "\"previously_invalidated_tokens\":2," + + "\"error_count\":2," + + "\"error_details\":[" + + "{\"type\":\"exception\"," + + "\"reason\":\"foo\"," + + "\"caused_by\":{" + + "\"type\":\"illegal_state_exception\"," + + "\"reason\":\"bar\"" + + "}" + + "}," + + "{\"type\":\"exception\"," + + "\"reason\":\"boo\"," + + "\"caused_by\":{" + + "\"type\":\"illegal_state_exception\"," + + "\"reason\":\"far\"" + + "}" + + "}" + + "]" + + "}")); + } + } + + public void testToXcontentWithNoErrors() throws Exception{ + TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"), + Collections.emptyList(), + Collections.emptyList(), randomIntBetween(0, 5)); + try (XContentBuilder builder = JsonXContent.contentBuilder()) { + result.toXContent(builder, ToXContent.EMPTY_PARAMS); + assertThat(Strings.toString(builder), + equalTo( + "{\"created\":true," + + "\"invalidated_tokens\":2," + + "\"previously_invalidated_tokens\":0," + + "\"error_count\":0" + + "}")); + } + } +} diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenActionTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenActionTests.java new file mode 100644 index 00000000000..00850ba6e5a --- /dev/null +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenActionTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security.rest.action.oauth2; + +import org.elasticsearch.common.xcontent.DeprecationHandler; +import org.elasticsearch.common.xcontent.NamedXContentRegistry; +import org.elasticsearch.common.xcontent.XContentParser; +import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest; + +import static org.hamcrest.Matchers.containsString; + +public class RestInvalidateTokenActionTests extends ESTestCase { + + public void testParserForUserAndRealm() throws Exception { + final String request = "{" + + "\"username\": \"user1\"," + + "\"realm_name\": \"realm1\"" + + "}"; + try (XContentParser parser = XContentType.JSON.xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) { + InvalidateTokenRequest invalidateTokenRequest = RestInvalidateTokenAction.PARSER.parse(parser, null); + assertEquals("user1", invalidateTokenRequest.getUserName()); + assertEquals("realm1", invalidateTokenRequest.getRealmName()); + assertNull(invalidateTokenRequest.getTokenString()); + assertNull(invalidateTokenRequest.getTokenType()); + } + } + + public void testParserForToken() throws Exception { + final String request = "{" + + "\"refresh_token\": \"refresh_token_string\"" + + "}"; + try (XContentParser parser = XContentType.JSON.xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) { + InvalidateTokenRequest invalidateTokenRequest = RestInvalidateTokenAction.PARSER.parse(parser, null); + assertEquals("refresh_token_string", invalidateTokenRequest.getTokenString()); + assertEquals("refresh_token", invalidateTokenRequest.getTokenType().getValue()); + assertNull(invalidateTokenRequest.getRealmName()); + assertNull(invalidateTokenRequest.getUserName()); + } + } + + public void testParserForIncorrectInput() throws Exception { + final String request = "{" + + "\"refresh_token\": \"refresh_token_string\"," + + "\"token\": \"access_token_string\"" + + "}"; + try (XContentParser parser = XContentType.JSON.xContent() + .createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestInvalidateTokenAction.PARSER.parse(parser, + null)); + assertThat(e.getCause().getMessage(), containsString("only one of [token, refresh_token] may be sent per request")); + + } + } +} diff --git a/x-pack/plugin/src/test/resources/rest-api-spec/test/token/10_basic.yml b/x-pack/plugin/src/test/resources/rest-api-spec/test/token/10_basic.yml index 43f25a11db0..81389ac8524 100644 --- a/x-pack/plugin/src/test/resources/rest-api-spec/test/token/10_basic.yml +++ b/x-pack/plugin/src/test/resources/rest-api-spec/test/token/10_basic.yml @@ -5,7 +5,7 @@ setup: - do: cluster.health: - wait_for_status: yellow + wait_for_status: yellow - do: security.put_user: @@ -79,7 +79,93 @@ teardown: body: token: $token - - match: { created: true } + - match: { created: true} + - match: { invalidated_tokens: 1 } + - match: { previously_invalidated_tokens: 0 } + - match: { error_count: 0 } + + - do: + catch: unauthorized + headers: + Authorization: Bearer ${token} + security.authenticate: {} + +--- +"Test invalidate user's tokens": + + - do: + security.get_token: + body: + grant_type: "password" + username: "token_user" + password: "x-pack-test-password" + + - match: { type: "Bearer" } + - is_true: access_token + - set: { access_token: token } + - match: { expires_in: 1200 } + - is_false: scope + + - do: + headers: + Authorization: Bearer ${token} + security.authenticate: {} + + - match: { username: "token_user" } + - match: { roles.0: "superuser" } + - match: { full_name: "Token User" } + + - do: + security.invalidate_token: + body: + username: "token_user" + + - match: { created: true} + - match: { invalidated_tokens: 2 } + - match: { previously_invalidated_tokens: 0 } + - match: { error_count: 0 } + + - do: + catch: unauthorized + headers: + Authorization: Bearer ${token} + security.authenticate: {} + + +--- +"Test invalidate realm's tokens": + + - do: + security.get_token: + body: + grant_type: "password" + username: "token_user" + password: "x-pack-test-password" + + - match: { type: "Bearer" } + - is_true: access_token + - set: { access_token: token } + - match: { expires_in: 1200 } + - is_false: scope + + - do: + headers: + Authorization: Bearer ${token} + security.authenticate: {} + + - match: { username: "token_user" } + - match: { roles.0: "superuser" } + - match: { full_name: "Token User" } + + - do: + security.invalidate_token: + body: + realm_name: "default_native" + + - match: { created: true} + - match: { invalidated_tokens: 2 } + - match: { previously_invalidated_tokens: 0 } + - match: { error_count: 0 } - do: catch: unauthorized