Enhance Invalidate Token API (#35388)

This change:

- Adds functionality to invalidate all (refresh+access) tokens for all users of a realm
- Adds functionality to invalidate all (refresh+access)tokens for a user in all realms
- Adds functionality to invalidate all (refresh+access) tokens for a user in a specific realm
- Changes the response format for the invalidate token API to contain information about the 
   number of the invalidated tokens and possible errors that were encountered.
- Updates the API Documentation

After back-porting to 6.x, the `created` field will be removed from master as a field in the 
response

Resolves: #35115
Relates: #34556
This commit is contained in:
Ioannis Kakavas 2018-12-18 10:05:50 +02:00 committed by GitHub
parent 96d279ed83
commit 7b9ca62174
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1429 additions and 309 deletions

View File

@ -1317,6 +1317,7 @@ public class SecurityDocumentationIT extends ESRestHighLevelClientTestCase {
}
}
@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/pull/36362")
public void testInvalidateToken() throws Exception {
RestHighLevelClient client = highLevelClient();

View File

@ -36,4 +36,4 @@ The returned +{response}+ contains a single property:
["source","java",subs="attributes,callouts,macros"]
--------------------------------------------------
include-tagged::{doc-tests-file}[{api}-response]
--------------------------------------------------
--------------------------------------------------

View File

@ -2,7 +2,7 @@
[[security-api-invalidate-token]]
=== Invalidate token API
Invalidates an access token or a refresh token.
Invalidates one or more access tokens or refresh tokens.
==== Request
@ -19,21 +19,31 @@ can no longer be used. That time period is defined by the
The refresh tokens returned by the <<security-api-get-token,get token API>> are
only valid for 24 hours. They can also be used exactly once.
If you want to invalidate an access or refresh token immediately, use this invalidate token API.
If you want to invalidate one or more access or refresh tokens immediately, use this invalidate token API.
==== Request Body
The following parameters can be specified in the body of a DELETE request and
pertain to invalidating a token:
pertain to invalidating tokens:
`token` (optional)::
(string) An access token. This parameter cannot be used when `refresh_token` is used.
(string) An access token. This parameter cannot be used any of `refresh_token`, `realm_name` or
`username` are used.
`refresh_token` (optional)::
(string) A refresh token. This parameter cannot be used when `token` is used.
(string) A refresh token. This parameter cannot be used any of `refresh_token`, `realm_name` or
`username` are used.
NOTE: One of `token` or `refresh_token` parameters is required.
`realm_name` (optional)::
(string) The name of an authentication realm. This parameter cannot be used with either `refresh_token` or `token`.
`username` (optional)::
(string) The username of a user. This parameter cannot be used with either `refresh_token` or `token`
NOTE: While all parameters are optional, at least one of them is required. More specifically, either one of `token`
or `refresh_token` parameters is required. If none of these two are specified, then `realm_name` and/or `username`
need to be specified.
==== Examples
@ -59,15 +69,75 @@ DELETE /_security/oauth2/token
--------------------------------------------------
// NOTCONSOLE
A successful call returns a JSON structure that indicates whether the token
has already been invalidated.
The following example invalidates all access tokens and refresh tokens for the `saml1` realm immediately:
[source,js]
--------------------------------------------------
DELETE /_xpack/security/oauth2/token
{
"created" : true <1>
"realm_name" : "saml1"
}
--------------------------------------------------
// NOTCONSOLE
<1> When a token has already been invalidated, `created` is set to false.
The following example invalidates all access tokens and refresh tokens for the user `myuser` in all realms immediately:
[source,js]
--------------------------------------------------
DELETE /_xpack/security/oauth2/token
{
"username" : "myuser"
}
--------------------------------------------------
// NOTCONSOLE
Finally, the following example invalidates all access tokens and refresh tokens for the user `myuser` in
the `saml1` realm immediately:
[source,js]
--------------------------------------------------
DELETE /_xpack/security/oauth2/token
{
"username" : "myuser",
"realm_name" : "saml1"
}
--------------------------------------------------
// NOTCONSOLE
A successful call returns a JSON structure that contains the number of tokens that were invalidated, the number
of tokens that had already been invalidated, and potentially a list of errors encountered while invalidating
specific tokens.
[source,js]
--------------------------------------------------
{
"invalidated_tokens":9, <1>
"previously_invalidated_tokens":15, <2>
"error_count":2, <3>
"error_details":[ <4>
{
"type":"exception",
"reason":"Elasticsearch exception [type=exception, reason=foo]",
"caused_by":{
"type":"exception",
"reason":"Elasticsearch exception [type=illegal_argument_exception, reason=bar]"
}
},
{
"type":"exception",
"reason":"Elasticsearch exception [type=exception, reason=boo]",
"caused_by":{
"type":"exception",
"reason":"Elasticsearch exception [type=illegal_argument_exception, reason=far]"
}
}
]
}
--------------------------------------------------
// NOTCONSOLE
<1> The number of the tokens that were invalidated as part of this request.
<2> The number of tokens that were already invalidated.
<3> The number of errors that were encountered when invalidating the tokens.
<4> Details about these errors. This field is not present in the response when
`error_count` is 0.

View File

@ -8,7 +8,7 @@ package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.action.Action;
/**
* Action for invalidating a given token
* Action for invalidating one or more tokens
*/
public final class InvalidateTokenAction extends Action<InvalidateTokenResponse> {

View File

@ -8,6 +8,7 @@ package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
@ -22,31 +23,81 @@ import static org.elasticsearch.action.ValidateActions.addValidationError;
public final class InvalidateTokenRequest extends ActionRequest {
public enum Type {
ACCESS_TOKEN,
REFRESH_TOKEN
ACCESS_TOKEN("token"),
REFRESH_TOKEN("refresh_token");
private final String value;
Type(String value) {
this.value = value;
}
public String getValue() {
return value;
}
public static Type fromString(String tokenType) {
if (tokenType != null) {
for (Type type : values()) {
if (type.getValue().equals(tokenType)) {
return type;
}
}
}
return null;
}
}
private String tokenString;
private Type tokenType;
private String realmName;
private String userName;
public InvalidateTokenRequest() {}
/**
* @param tokenString the string representation of the token
* @param tokenString the string representation of the token to be invalidated
* @param tokenType the type of the token to be invalidated
* @param realmName the name of the realm for which all tokens will be invalidated
* @param userName the principal of the user for which all tokens will be invalidated
*/
public InvalidateTokenRequest(String tokenString, Type type) {
public InvalidateTokenRequest(@Nullable String tokenString, @Nullable String tokenType,
@Nullable String realmName, @Nullable String userName) {
this.tokenString = tokenString;
this.tokenType = type;
this.tokenType = Type.fromString(tokenType);
this.realmName = realmName;
this.userName = userName;
}
/**
* @param tokenString the string representation of the token to be invalidated
* @param tokenType the type of the token to be invalidated
*/
public InvalidateTokenRequest(String tokenString, String tokenType) {
this.tokenString = tokenString;
this.tokenType = Type.fromString(tokenType);
this.realmName = null;
this.userName = null;
}
@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException validationException = null;
if (Strings.isNullOrEmpty(tokenString)) {
validationException = addValidationError("token string must be provided", null);
}
if (tokenType == null) {
validationException = addValidationError("token type must be provided", validationException);
if (Strings.hasText(realmName) || Strings.hasText(userName)) {
if (Strings.hasText(tokenString)) {
validationException =
addValidationError("token string must not be provided when realm name or username is specified", null);
}
if (tokenType != null) {
validationException =
addValidationError("token type must not be provided when realm name or username is specified", validationException);
}
} else if (Strings.isNullOrEmpty(tokenString)) {
validationException =
addValidationError("token string must be provided when not specifying a realm name or a username", null);
} else if (tokenType == null) {
validationException =
addValidationError("token type must be provided when a token string is specified", null);
}
return validationException;
}
@ -67,26 +118,76 @@ public final class InvalidateTokenRequest extends ActionRequest {
this.tokenType = tokenType;
}
public String getRealmName() {
return realmName;
}
public void setRealmName(String realmName) {
this.realmName = realmName;
}
public String getUserName() {
return userName;
}
public void setUserName(String userName) {
this.userName = userName;
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(tokenString);
if (out.getVersion().before(Version.V_7_0_0)) {
if (Strings.isNullOrEmpty(tokenString)) {
throw new IllegalArgumentException("token is required for versions < v6.6.0");
}
out.writeString(tokenString);
} else {
out.writeOptionalString(tokenString);
}
if (out.getVersion().onOrAfter(Version.V_6_2_0)) {
out.writeVInt(tokenType.ordinal());
if (out.getVersion().before(Version.V_7_0_0)) {
if (tokenType == null) {
throw new IllegalArgumentException("token type is not optional for versions > v6.2.0 and < v6.6.0");
}
out.writeVInt(tokenType.ordinal());
} else {
out.writeOptionalVInt(tokenType == null ? null : tokenType.ordinal());
}
} else if (tokenType == Type.REFRESH_TOKEN) {
throw new IllegalArgumentException("refresh token invalidation cannot be serialized with version [" + out.getVersion() +
"]");
throw new IllegalArgumentException("refresh token invalidation cannot be serialized with version [" + out.getVersion() + "]");
}
if (out.getVersion().onOrAfter(Version.V_7_0_0)) {
out.writeOptionalString(realmName);
out.writeOptionalString(userName);
} else if (realmName != null || userName != null) {
throw new IllegalArgumentException(
"realm or user token invalidation cannot be serialized with version [" + out.getVersion() + "]");
}
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
tokenString = in.readString();
if (in.getVersion().before(Version.V_7_0_0)) {
tokenString = in.readString();
} else {
tokenString = in.readOptionalString();
}
if (in.getVersion().onOrAfter(Version.V_6_2_0)) {
tokenType = Type.values()[in.readVInt()];
if (in.getVersion().before(Version.V_7_0_0)) {
int type = in.readVInt();
tokenType = Type.values()[type];
} else {
Integer type = in.readOptionalVInt();
tokenType = type == null ? null : Type.values()[type];
}
} else {
tokenType = Type.ACCESS_TOKEN;
}
if (in.getVersion().onOrAfter(Version.V_7_0_0)) {
realmName = in.readOptionalString();
userName = in.readOptionalString();
}
}
}

View File

@ -34,4 +34,20 @@ public final class InvalidateTokenRequestBuilder
request.setTokenType(type);
return this;
}
/**
* Sets the name of the realm for which all tokens should be invalidated
*/
public InvalidateTokenRequestBuilder setRealmName(String realmName) {
request.setRealmName(realmName);
return this;
}
/**
* Sets the username for which all tokens should be invalidated
*/
public InvalidateTokenRequestBuilder setUserName(String username) {
request.setUserName(username);
return this;
}
}

View File

@ -5,41 +5,83 @@
*/
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.Objects;
/**
* Response for a invalidation of a token.
* Response for a invalidation of one or multiple tokens.
*/
public final class InvalidateTokenResponse extends ActionResponse {
public final class InvalidateTokenResponse extends ActionResponse implements ToXContent {
private boolean created;
private TokensInvalidationResult result;
public InvalidateTokenResponse() {}
public InvalidateTokenResponse(boolean created) {
this.created = created;
public InvalidateTokenResponse(TokensInvalidationResult result) {
this.result = result;
}
/**
* If the token is already invalidated then created will be <code>false</code>
*/
public boolean isCreated() {
return created;
public TokensInvalidationResult getResult() {
return result;
}
private boolean isCreated() {
return result.getInvalidatedTokens().size() > 0
&& result.getPreviouslyInvalidatedTokens().isEmpty()
&& result.getErrors().isEmpty();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeBoolean(created);
if (out.getVersion().before(Version.V_7_0_0)) {
out.writeBoolean(isCreated());
} else {
result.writeTo(out);
}
}
@Override
public void readFrom(StreamInput in) throws IOException {
super.readFrom(in);
created = in.readBoolean();
if (in.getVersion().before(Version.V_7_0_0)) {
final boolean created = in.readBoolean();
if (created) {
result = new TokensInvalidationResult(Arrays.asList(""), Collections.emptyList(), Collections.emptyList(), 0);
} else {
result = new TokensInvalidationResult(Collections.emptyList(), Arrays.asList(""), Collections.emptyList(), 0);
}
} else {
result = new TokensInvalidationResult(in);
}
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException {
result.toXContent(builder, params);
return builder;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
InvalidateTokenResponse that = (InvalidateTokenResponse) o;
return Objects.equals(result, that.result);
}
@Override
public int hashCode() {
return Objects.hash(result);
}
}

View File

@ -0,0 +1,113 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.security.authc.support;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.common.xcontent.ToXContentObject;
import org.elasticsearch.common.xcontent.XContentBuilder;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* The result of attempting to invalidate one or multiple tokens. The result contains information about:
* <ul>
* <li>how many of the tokens were actually invalidated</li>
* <li>how many tokens are not invalidated in this request because they were already invalidated</li>
* <li>how many errors were encountered while invalidating tokens and the error details</li>
* </ul>
*/
public class TokensInvalidationResult implements ToXContentObject, Writeable {
private final List<String> invalidatedTokens;
private final List<String> previouslyInvalidatedTokens;
private final List<ElasticsearchException> errors;
private final int attemptCount;
public TokensInvalidationResult(List<String> invalidatedTokens, List<String> previouslyInvalidatedTokens,
@Nullable List<ElasticsearchException> errors, int attemptCount) {
Objects.requireNonNull(invalidatedTokens, "invalidated_tokens must be provided");
this.invalidatedTokens = invalidatedTokens;
Objects.requireNonNull(previouslyInvalidatedTokens, "previously_invalidated_tokens must be provided");
this.previouslyInvalidatedTokens = previouslyInvalidatedTokens;
if (null != errors) {
this.errors = errors;
} else {
this.errors = Collections.emptyList();
}
this.attemptCount = attemptCount;
}
public TokensInvalidationResult(StreamInput in) throws IOException {
this.invalidatedTokens = in.readList(StreamInput::readString);
this.previouslyInvalidatedTokens = in.readList(StreamInput::readString);
this.errors = in.readList(StreamInput::readException);
this.attemptCount = in.readVInt();
}
public static TokensInvalidationResult emptyResult() {
return new TokensInvalidationResult(Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), 0);
}
public List<String> getInvalidatedTokens() {
return invalidatedTokens;
}
public List<String> getPreviouslyInvalidatedTokens() {
return previouslyInvalidatedTokens;
}
public List<ElasticsearchException> getErrors() {
return errors;
}
public int getAttemptCount() {
return attemptCount;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject()
//Remove created after PR is backported to 6.x
.field("created", isCreated())
.field("invalidated_tokens", invalidatedTokens.size())
.field("previously_invalidated_tokens", previouslyInvalidatedTokens.size())
.field("error_count", errors.size());
if (errors.isEmpty() == false) {
builder.field("error_details");
builder.startArray();
for (ElasticsearchException e : errors) {
builder.startObject();
ElasticsearchException.generateThrowableXContent(builder, params, e);
builder.endObject();
}
builder.endArray();
}
return builder.endObject();
}
@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeStringList(invalidatedTokens);
out.writeStringList(previouslyInvalidatedTokens);
out.writeCollection(errors, StreamOutput::writeException);
out.writeVInt(attemptCount);
}
private boolean isCreated() {
return this.getInvalidatedTokens().size() > 0
&& this.getPreviouslyInvalidatedTokens().isEmpty()
&& this.getErrors().isEmpty();
}
}

View File

@ -326,6 +326,10 @@ public class SecurityClient {
return new InvalidateTokenRequestBuilder(client).setTokenString(token);
}
public InvalidateTokenRequestBuilder prepareInvalidateToken() {
return new InvalidateTokenRequestBuilder(client);
}
public void invalidateToken(InvalidateTokenRequest request, ActionListener<InvalidateTokenResponse> listener) {
client.execute(InvalidateTokenAction.INSTANCE, request, listener);
}

View File

@ -8,7 +8,6 @@ package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.security.action.token.CreateTokenRequest;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.hasItem;

View File

@ -0,0 +1,82 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.test.ESTestCase;
import static org.hamcrest.Matchers.containsString;
public class InvalidateTokenRequestTests extends ESTestCase {
public void testValidation() {
InvalidateTokenRequest request = new InvalidateTokenRequest();
ActionRequestValidationException ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0), containsString("token string must be provided when not specifying a realm"));
request = new InvalidateTokenRequest(randomAlphaOfLength(12), randomFrom("", null));
ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0), containsString("token type must be provided when a token string is specified"));
request = new InvalidateTokenRequest(randomFrom("", null), "access_token");
ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0), containsString("token string must be provided when not specifying a realm"));
request = new InvalidateTokenRequest(randomFrom("", null), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8));
ve = request.validate();
assertNull(ve);
request =
new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8));
ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0),
containsString("token string must not be provided when realm name or username is specified"));
request = new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("token", "refresh_token"),
randomAlphaOfLength(4), randomAlphaOfLength(8));
ve = request.validate();
assertNotNull(ve);
assertEquals(2, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0),
containsString("token string must not be provided when realm name or username is specified"));
assertThat(ve.validationErrors().get(1),
containsString("token type must not be provided when realm name or username is specified"));
request =
new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomAlphaOfLength(4), randomAlphaOfLength(8));
ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0),
containsString("token string must not be provided when realm name or username is specified"));
request =
new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("token", "refresh_token"), randomFrom("", null),
randomAlphaOfLength(8));
ve = request.validate();
assertNotNull(ve);
assertEquals(2, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0),
containsString("token string must not be provided when realm name or username is specified"));
assertThat(ve.validationErrors().get(1),
containsString("token type must not be provided when realm name or username is specified"));
request = new InvalidateTokenRequest(randomAlphaOfLength(4), randomFrom("", null), randomFrom("", null), randomAlphaOfLength(8));
ve = request.validate();
assertNotNull(ve);
assertEquals(1, ve.validationErrors().size());
assertThat(ve.validationErrors().get(0),
containsString("token string must not be provided when realm name or username is specified"));
}
}

View File

@ -0,0 +1,141 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.core.security.action.token;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.VersionUtils;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
public class InvalidateTokenResponseTests extends ESTestCase {
public void testSerialization() throws IOException {
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
InvalidateTokenResponse serialized = new InvalidateTokenResponse();
serialized.readFrom(input);
assertThat(serialized.getResult().getInvalidatedTokens(), equalTo(response.getResult().getInvalidatedTokens()));
assertThat(serialized.getResult().getPreviouslyInvalidatedTokens(),
equalTo(response.getResult().getPreviouslyInvalidatedTokens()));
assertThat(serialized.getResult().getErrors().size(), equalTo(response.getResult().getErrors().size()));
assertThat(serialized.getResult().getErrors().get(0).toString(), containsString("this is an error message"));
assertThat(serialized.getResult().getErrors().get(1).toString(), containsString("this is an error message2"));
}
}
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false)),
Arrays.asList(generateRandomStringArray(20, 15, false)),
Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
InvalidateTokenResponse serialized = new InvalidateTokenResponse();
serialized.readFrom(input);
assertThat(serialized.getResult().getInvalidatedTokens(), equalTo(response.getResult().getInvalidatedTokens()));
assertThat(serialized.getResult().getPreviouslyInvalidatedTokens(),
equalTo(response.getResult().getPreviouslyInvalidatedTokens()));
assertThat(serialized.getResult().getErrors().size(), equalTo(response.getResult().getErrors().size()));
}
}
}
public void testSerializationToPre66Version() throws IOException{
final Version version = VersionUtils.randomVersionBetween(random(), Version.V_6_2_0, Version.V_6_5_1);
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
// False as we have errors and previously invalidated tokens
assertThat(input.readBoolean(), equalTo(false));
}
}
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
// False as we have previously invalidated tokens
assertThat(input.readBoolean(), equalTo(false));
}
}
result = new TokensInvalidationResult(Arrays.asList(generateRandomStringArray(20, 15, false, false)),
Collections.emptyList(), Collections.emptyList(), randomIntBetween(0, 5));
response = new InvalidateTokenResponse(result);
try (BytesStreamOutput output = new BytesStreamOutput()) {
output.setVersion(version);
response.writeTo(output);
try (StreamInput input = output.bytes().streamInput()) {
assertThat(input.readBoolean(), equalTo(true));
}
}
}
public void testToXContent() throws IOException {
List invalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
List previouslyInvalidatedTokens = Arrays.asList(generateRandomStringArray(20, 15, false));
TokensInvalidationResult result = new TokensInvalidationResult(invalidatedTokens, previouslyInvalidatedTokens,
Arrays.asList(new ElasticsearchException("foo", new IllegalArgumentException("this is an error message")),
new ElasticsearchException("bar", new IllegalArgumentException("this is an error message2"))),
randomIntBetween(0, 5));
InvalidateTokenResponse response = new InvalidateTokenResponse(result);
XContentBuilder builder = XContentFactory.jsonBuilder();
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo("{\"created\":false," +
"\"invalidated_tokens\":" + invalidatedTokens.size() + "," +
"\"previously_invalidated_tokens\":" + previouslyInvalidatedTokens.size() + "," +
"\"error_count\":2," +
"\"error_details\":[" +
"{\"type\":\"exception\"," +
"\"reason\":\"foo\"," +
"\"caused_by\":{" +
"\"type\":\"illegal_argument_exception\"," +
"\"reason\":\"this is an error message\"}" +
"}," +
"{\"type\":\"exception\"," +
"\"reason\":\"bar\"," +
"\"caused_by\":" +
"{\"type\":\"illegal_argument_exception\"," +
"\"reason\":\"this is an error message2\"}" +
"}" +
"]" +
"}"));
}
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionAction;
import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionRequest;
import org.elasticsearch.xpack.core.security.action.saml.SamlInvalidateSessionResponse;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.authc.TokenService;
import org.elasticsearch.xpack.security.authc.UserToken;
@ -27,12 +28,11 @@ import org.elasticsearch.xpack.security.authc.saml.SamlRedirect;
import org.elasticsearch.xpack.security.authc.saml.SamlUtils;
import org.opensaml.saml.saml2.core.LogoutResponse;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.function.Predicate;
import static org.elasticsearch.xpack.security.authc.saml.SamlRealm.findSamlRealms;
@ -85,7 +85,7 @@ public final class TransportSamlInvalidateSessionAction
private void findAndInvalidateTokens(SamlRealm realm, SamlLogoutRequestHandler.Result result, ActionListener<Integer> listener) {
final Map<String, Object> tokenMetadata = realm.createTokenMetadata(result.getNameId(), result.getSession());
if (Strings.hasText((String) tokenMetadata.get(SamlRealm.TOKEN_METADATA_NAMEID_VALUE)) == false) {
if (Strings.isNullOrEmpty((String) tokenMetadata.get(SamlRealm.TOKEN_METADATA_NAMEID_VALUE))) {
// If we don't have a valid name-id to match against, don't do anything
logger.debug("Logout request [{}] has no NameID value, so cannot invalidate any sessions", result);
listener.onResponse(0);
@ -93,22 +93,21 @@ public final class TransportSamlInvalidateSessionAction
}
tokenService.findActiveTokensForRealm(realm.name(), ActionListener.wrap(tokens -> {
List<Tuple<UserToken, String>> sessionTokens = filterTokens(tokens, tokenMetadata);
logger.debug("Found [{}] token pairs to invalidate for SAML metadata [{}]", sessionTokens.size(), tokenMetadata);
if (sessionTokens.isEmpty()) {
listener.onResponse(0);
} else {
GroupedActionListener<Boolean> groupedListener = new GroupedActionListener<>(
ActionListener.wrap(collection -> listener.onResponse(collection.size()), listener::onFailure),
sessionTokens.size(), Collections.emptyList()
);
sessionTokens.forEach(tuple -> invalidateTokenPair(tuple, groupedListener));
}
}, e -> listener.onFailure(e)
));
logger.debug("Found [{}] token pairs to invalidate for SAML metadata [{}]", tokens.size(), tokenMetadata);
if (tokens.isEmpty()) {
listener.onResponse(0);
} else {
GroupedActionListener<TokensInvalidationResult> groupedListener = new GroupedActionListener<>(
ActionListener.wrap(collection -> listener.onResponse(collection.size()), listener::onFailure),
tokens.size(), Collections.emptyList()
);
tokens.forEach(tuple -> invalidateTokenPair(tuple, groupedListener));
}
}, listener::onFailure
), containsMetadata(tokenMetadata));
}
private void invalidateTokenPair(Tuple<UserToken, String> tokenPair, ActionListener<Boolean> listener) {
private void invalidateTokenPair(Tuple<UserToken, String> tokenPair, ActionListener<TokensInvalidationResult> listener) {
// Invalidate the refresh token first, so the client doesn't trigger a refresh once the access token is invalidated
tokenService.invalidateRefreshToken(tokenPair.v2(), ActionListener.wrap(ignore -> tokenService.invalidateAccessToken(
tokenPair.v1(),
@ -118,13 +117,12 @@ public final class TransportSamlInvalidateSessionAction
})), listener::onFailure));
}
private List<Tuple<UserToken, String>> filterTokens(Collection<Tuple<UserToken, String>> tokens, Map<String, Object> requiredMetadata) {
return tokens.stream()
.filter(tup -> {
Map<String, Object> actualMetadata = tup.v1().getMetadata();
return requiredMetadata.entrySet().stream().allMatch(e -> Objects.equals(actualMetadata.get(e.getKey()), e.getValue()));
})
.collect(Collectors.toList());
private Predicate<Map<String, Object>> containsMetadata(Map<String, Object> requiredMetadata) {
return source -> {
Map<String, Object> actualMetadata = (Map<String, Object>) source.get("metadata");
return requiredMetadata.entrySet().stream().allMatch(e -> Objects.equals(actualMetadata.get(e.getKey()), e.getValue()));
};
}
}

View File

@ -18,6 +18,7 @@ import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutRequest;
import org.elasticsearch.xpack.core.security.action.saml.SamlLogoutResponse;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.Realm;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.security.authc.Realms;
import org.elasticsearch.xpack.security.authc.TokenService;
@ -79,7 +80,7 @@ public final class TransportSamlLogoutAction
}, listener::onFailure));
}
private void invalidateRefreshToken(String refreshToken, ActionListener<Boolean> listener) {
private void invalidateRefreshToken(String refreshToken, ActionListener<TokensInvalidationResult> listener) {
if (refreshToken == null) {
listener.onResponse(null);
} else {

View File

@ -8,12 +8,14 @@ package org.elasticsearch.xpack.security.action.token;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.HandledTransportAction;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenAction;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenResponse;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.security.authc.TokenService;
/**
@ -31,9 +33,12 @@ public final class TransportInvalidateTokenAction extends HandledTransportAction
@Override
protected void doExecute(Task task, InvalidateTokenRequest request, ActionListener<InvalidateTokenResponse> listener) {
final ActionListener<Boolean> invalidateListener =
ActionListener.wrap(created -> listener.onResponse(new InvalidateTokenResponse(created)), listener::onFailure);
if (request.getTokenType() == InvalidateTokenRequest.Type.ACCESS_TOKEN) {
final ActionListener<TokensInvalidationResult> invalidateListener =
ActionListener.wrap(tokensInvalidationResult ->
listener.onResponse(new InvalidateTokenResponse(tokensInvalidationResult)), listener::onFailure);
if (Strings.hasText(request.getUserName()) || Strings.hasText(request.getRealmName())) {
tokenService.invalidateActiveTokensForRealmAndUser(request.getRealmName(), request.getUserName(), invalidateListener);
} else if (request.getTokenType() == InvalidateTokenRequest.Type.ACCESS_TOKEN) {
tokenService.invalidateAccessToken(request.getTokenString(), invalidateListener);
} else {
assert request.getTokenType() == InvalidateTokenRequest.Type.REFRESH_TOKEN;

View File

@ -17,6 +17,11 @@ import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteRequest.OpType;
import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetRequest;
import org.elasticsearch.action.get.GetResponse;
import org.elasticsearch.action.get.MultiGetItemResponse;
@ -24,7 +29,6 @@ import org.elasticsearch.action.get.MultiGetRequest;
import org.elasticsearch.action.get.MultiGetResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
@ -39,6 +43,7 @@ import org.elasticsearch.cluster.ack.AckedRequest;
import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse;
import org.elasticsearch.cluster.metadata.MetaData;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Priority;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.UUIDs;
@ -61,7 +66,6 @@ import org.elasticsearch.common.util.iterable.Iterables;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.XContentFactory;
import org.elasticsearch.core.internal.io.IOUtils;
import org.elasticsearch.index.engine.DocumentMissingException;
import org.elasticsearch.index.engine.VersionConflictEngineException;
import org.elasticsearch.index.query.BoolQueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
@ -74,6 +78,7 @@ import org.elasticsearch.xpack.core.security.ScrollHelper;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.KeyAndTimestamp;
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.security.support.SecurityIndexManager;
import javax.crypto.Cipher;
@ -90,6 +95,7 @@ import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
@ -116,6 +122,8 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException;
import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK;
@ -221,9 +229,9 @@ public final class TokenService {
boolean includeRefreshToken) throws IOException {
ensureEnabled();
if (authentication == null) {
listener.onFailure(traceLog("create token", null, new IllegalArgumentException("authentication must be provided")));
listener.onFailure(traceLog("create token", new IllegalArgumentException("authentication must be provided")));
} else if (originatingClientAuth == null) {
listener.onFailure(traceLog("create token", null,
listener.onFailure(traceLog("create token",
new IllegalArgumentException("originating client authentication must be provided")));
} else {
final Instant created = clock.instant();
@ -471,7 +479,7 @@ public final class TokenService {
* have been created on versions on or after 6.2; this step involves performing an update to
* the token document and setting the <code>invalidated</code> field to <code>true</code>
*/
public void invalidateAccessToken(String tokenString, ActionListener<Boolean> listener) {
public void invalidateAccessToken(String tokenString, ActionListener<TokensInvalidationResult> listener) {
ensureEnabled();
if (Strings.isNullOrEmpty(tokenString)) {
logger.trace("No token-string provided");
@ -484,7 +492,8 @@ public final class TokenService {
listener.onFailure(traceLog("invalidate token", tokenString, malformedTokenException()));
} else {
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
indexBwcInvalidation(userToken, listener, new AtomicInteger(0), expirationEpochMilli);
indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0),
expirationEpochMilli, null);
}
}, listener::onFailure));
} catch (IOException e) {
@ -499,7 +508,7 @@ public final class TokenService {
*
* @see #invalidateAccessToken(String, ActionListener)
*/
public void invalidateAccessToken(UserToken userToken, ActionListener<Boolean> listener) {
public void invalidateAccessToken(UserToken userToken, ActionListener<TokensInvalidationResult> listener) {
ensureEnabled();
if (userToken == null) {
logger.trace("No access token provided");
@ -507,11 +516,17 @@ public final class TokenService {
} else {
maybeStartTokenRemover();
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
indexBwcInvalidation(userToken, listener, new AtomicInteger(0), expirationEpochMilli);
indexBwcInvalidation(Collections.singleton(userToken.getId()), listener, new AtomicInteger(0), expirationEpochMilli, null);
}
}
public void invalidateRefreshToken(String refreshToken, ActionListener<Boolean> listener) {
/**
* This method performs the steps necessary to invalidate a refresh token so that it may no longer be used.
*
* @param refreshToken The string representation of the refresh token
* @param listener the listener to notify upon completion
*/
public void invalidateRefreshToken(String refreshToken, ActionListener<TokensInvalidationResult> listener) {
ensureEnabled();
if (Strings.isNullOrEmpty(refreshToken)) {
logger.trace("No refresh token provided");
@ -520,152 +535,222 @@ public final class TokenService {
maybeStartTokenRemover();
findTokenFromRefreshToken(refreshToken,
ActionListener.wrap(tuple -> {
final String docId = tuple.v1().getHits().getAt(0).getId();
final long docVersion = tuple.v1().getHits().getAt(0).getVersion();
indexInvalidation(docId, Version.CURRENT, listener, tuple.v2(), "refresh_token", docVersion);
final String docId = getTokenIdFromDocumentId(tuple.v1().getHits().getAt(0).getId());
indexInvalidation(Collections.singletonList(docId), listener, tuple.v2(), "refresh_token", null);
}, listener::onFailure), new AtomicInteger(0));
}
}
/**
* Performs the actual bwc invalidation of a token and then kicks off the new invalidation method
* Invalidate all access tokens and all refresh tokens of a given {@code realmName} and/or of a given
* {@code username} so that they may no longer be used
*
* @param userToken the token to invalidate
* @param listener the listener to notify upon completion
* @param attemptCount the number of attempts to invalidate that have already been tried
* @param expirationEpochMilli the expiration time as milliseconds since the epoch
* @param realmName the realm of which the tokens should be invalidated
* @param username the username for which the tokens should be invalidated
* @param listener the listener to notify upon completion
*/
private void indexBwcInvalidation(UserToken userToken, ActionListener<Boolean> listener, AtomicInteger attemptCount,
long expirationEpochMilli) {
if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate token [{}] after [{}] attempts", userToken.getId(), attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate token"));
public void invalidateActiveTokensForRealmAndUser(@Nullable String realmName, @Nullable String username,
ActionListener<TokensInvalidationResult> listener) {
ensureEnabled();
if (Strings.isNullOrEmpty(realmName) && Strings.isNullOrEmpty(username)) {
logger.trace("No realm name or username provided");
listener.onFailure(new IllegalArgumentException("realm name or username must be provided"));
} else {
final String invalidatedTokenId = getInvalidatedTokenDocumentId(userToken);
IndexRequest indexRequest = client.prepareIndex(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId)
.setOpType(OpType.CREATE)
.setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli)
.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
.request();
final String tokenDocId = getTokenDocumentId(userToken);
final Version version = userToken.getVersion();
securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", tokenDocId, ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, indexRequest,
ActionListener.<IndexResponse>wrap(indexResponse -> {
ActionListener<Boolean> wrappedListener =
ActionListener.wrap(ignore -> listener.onResponse(true), listener::onFailure);
indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L);
}, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e);
traceLog("(bwc) invalidate token", tokenDocId, cause);
if (cause instanceof VersionConflictEngineException) {
// expected since something else could have invalidated
ActionListener<Boolean> wrappedListener =
ActionListener.wrap(ignore -> listener.onResponse(false), listener::onFailure);
indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L);
} else if (isShardNotAvailableException(e)) {
attemptCount.incrementAndGet();
indexBwcInvalidation(userToken, listener, attemptCount, expirationEpochMilli);
} else {
listener.onFailure(e);
}
}), client::index));
if (Strings.isNullOrEmpty(realmName)) {
findActiveTokensForUser(username, ActionListener.wrap(tokenTuples -> {
if (tokenTuples.isEmpty()) {
logger.warn("No tokens to invalidate for realm [{}] and username [{}]", realmName, username);
listener.onResponse(TokensInvalidationResult.emptyResult());
} else {
invalidateAllTokens(tokenTuples.stream().map(t -> t.v1().getId()).collect(Collectors.toList()), listener);
}
}, listener::onFailure));
} else {
Predicate filter = null;
if (Strings.hasText(username)) {
filter = isOfUser(username);
}
findActiveTokensForRealm(realmName, ActionListener.wrap(tokenTuples -> {
if (tokenTuples.isEmpty()) {
logger.warn("No tokens to invalidate for realm [{}] and username [{}]", realmName, username);
listener.onResponse(TokensInvalidationResult.emptyResult());
} else {
invalidateAllTokens(tokenTuples.stream().map(t -> t.v1().getId()).collect(Collectors.toList()), listener);
}
}, listener::onFailure), filter);
}
}
}
/**
* Performs the actual invalidation of a token
* Invalidates a collection of access_token and refresh_token that were retrieved by
* {@link TokenService#invalidateActiveTokensForRealmAndUser}
*
* @param tokenDocId the id of the token doc to invalidate
* @param accessTokenIds The ids of the access tokens which should be invalidated (along with the respective refresh_token)
* @param listener the listener to notify upon completion
*/
private void invalidateAllTokens(Collection<String> accessTokenIds, ActionListener<TokensInvalidationResult> listener) {
maybeStartTokenRemover();
final long expirationEpochMilli = getExpirationTime().toEpochMilli();
// Invalidate the refresh tokens first so that they cannot be used to get new
// access tokens while we invalidate the access tokens we currently know about
indexInvalidation(accessTokenIds, ActionListener.wrap(result ->
indexBwcInvalidation(accessTokenIds, listener, new AtomicInteger(result.getAttemptCount()),
expirationEpochMilli, result),
listener::onFailure), new AtomicInteger(0), "refresh_token", null);
}
/**
* Performs the actual bwc invalidation of a collection of tokens and then kicks off the new invalidation method.
*
* @param tokenIds the collection of token ids or token document ids that should be invalidated
* @param listener the listener to notify upon completion
* @param attemptCount the number of attempts to invalidate that have already been tried
* @param expirationEpochMilli the expiration time as milliseconds since the epoch
* @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating
* tokens up to the point of the retry. This result is added to the result of the current attempt
*/
private void indexBwcInvalidation(Collection<String> tokenIds, ActionListener<TokensInvalidationResult> listener,
AtomicInteger attemptCount, long expirationEpochMilli,
@Nullable TokensInvalidationResult previousResult) {
if (tokenIds.isEmpty()) {
logger.warn("No tokens provided for invalidation");
listener.onFailure(invalidGrantException("No tokens provided for invalidation"));
} else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(),
attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate tokens"));
} else {
BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
for (String tokenId : tokenIds) {
final String invalidatedTokenId = getInvalidatedTokenDocumentId(tokenId);
IndexRequest indexRequest = client.prepareIndex(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId)
.setOpType(OpType.CREATE)
.setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli)
.request();
bulkRequestBuilder.add(indexRequest);
}
bulkRequestBuilder.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL);
final BulkRequest bulkRequest = bulkRequestBuilder.request();
securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequest,
ActionListener.<BulkResponse>wrap(bulkResponse -> {
List<String> retryTokenIds = new ArrayList<>();
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
if (bulkItemResponse.isFailed()) {
Throwable cause = bulkItemResponse.getFailure().getCause();
logger.error(cause.getMessage());
traceLog("(bwc) invalidate tokens", cause);
if (isShardNotAvailableException(cause)) {
retryTokenIds.add(getTokenIdFromInvalidatedTokenDocumentId(bulkItemResponse.getFailure().getId()));
} else if ((cause instanceof VersionConflictEngineException) == false){
// We don't handle VersionConflictEngineException, the ticket has been invalidated
listener.onFailure(bulkItemResponse.getFailure().getCause());
}
}
}
if (retryTokenIds.isEmpty() == false) {
attemptCount.incrementAndGet();
indexBwcInvalidation(retryTokenIds, listener, attemptCount, expirationEpochMilli, previousResult);
}
indexInvalidation(tokenIds, listener, attemptCount, "access_token", previousResult);
}, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e);
traceLog("(bwc) invalidate tokens", cause);
if (isShardNotAvailableException(cause)) {
attemptCount.incrementAndGet();
indexBwcInvalidation(tokenIds, listener, attemptCount, expirationEpochMilli, previousResult);
} else {
listener.onFailure(e);
}
}),
client::bulk));
}
}
/**
* Performs the actual invalidation of a collection of tokens
*
* @param tokenIds the tokens to invalidate
* @param listener the listener to notify upon completion
* @param attemptCount the number of attempts to invalidate that have already been tried
* @param srcPrefix the prefix to use when constructing the doc to update
* @param documentVersion the expected version of the document we will update
* @param srcPrefix the prefix to use when constructing the doc to update, either refresh_token or access_token depending on
* what type of tokens should be invalidated
* @param previousResult if this not the initial attempt for invalidation, it contains the result of invalidating
* tokens up to the point of the retry. This result is added to the result of the current attempt
*/
private void indexInvalidation(String tokenDocId, Version version, ActionListener<Boolean> listener, AtomicInteger attemptCount,
String srcPrefix, long documentVersion) {
if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate token [{}] after [{}] attempts", tokenDocId, attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate token"));
private void indexInvalidation(Collection<String> tokenIds, ActionListener<TokensInvalidationResult> listener,
AtomicInteger attemptCount, String srcPrefix, @Nullable TokensInvalidationResult previousResult) {
if (tokenIds.isEmpty()) {
logger.warn("No [{}] tokens provided for invalidation", srcPrefix);
listener.onFailure(invalidGrantException("No tokens provided for invalidation"));
} else if (attemptCount.get() > MAX_RETRY_ATTEMPTS) {
logger.warn("Failed to invalidate [{}] tokens after [{}] attempts", tokenIds.size(),
attemptCount.get());
listener.onFailure(invalidGrantException("failed to invalidate tokens"));
} else {
UpdateRequest request = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId)
BulkRequestBuilder bulkRequestBuilder = client.prepareBulk();
for (String tokenId : tokenIds) {
UpdateRequest request = client.prepareUpdate(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, getTokenDocumentId(tokenId))
.setDoc(srcPrefix, Collections.singletonMap("invalidated", true))
.setVersion(documentVersion)
.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL)
.setFetchSource(srcPrefix, null)
.request();
securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", tokenDocId, ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request,
ActionListener.<UpdateResponse>wrap(updateResponse -> {
logger.debug("Invalidated [{}] for doc [{}]", srcPrefix, tokenDocId);
if (updateResponse.getGetResult() != null
&& updateResponse.getGetResult().sourceAsMap().containsKey(srcPrefix)
&& ((Map<String, Object>) updateResponse.getGetResult().sourceAsMap().get(srcPrefix))
.containsKey("invalidated")) {
final boolean prevInvalidated = (boolean)
((Map<String, Object>) updateResponse.getGetResult().sourceAsMap().get(srcPrefix))
.get("invalidated");
listener.onResponse(prevInvalidated == false);
} else {
listener.onResponse(true);
bulkRequestBuilder.add(request);
}
bulkRequestBuilder.setRefreshPolicy(RefreshPolicy.WAIT_UNTIL);
securityIndex.prepareIndexIfNeededThenExecute(ex -> listener.onFailure(traceLog("prepare security index", ex)),
() -> executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, bulkRequestBuilder.request(),
ActionListener.<BulkResponse>wrap(bulkResponse -> {
ArrayList<String> retryTokenDocIds = new ArrayList<>();
ArrayList<ElasticsearchException> failedRequestResponses = new ArrayList<>();
ArrayList<String> previouslyInvalidated = new ArrayList<>();
ArrayList<String> invalidated = new ArrayList<>();
if (null != previousResult) {
failedRequestResponses.addAll((previousResult.getErrors()));
previouslyInvalidated.addAll(previousResult.getPreviouslyInvalidatedTokens());
invalidated.addAll(previousResult.getInvalidatedTokens());
}
for (BulkItemResponse bulkItemResponse : bulkResponse.getItems()) {
if (bulkItemResponse.isFailed()) {
Throwable cause = bulkItemResponse.getFailure().getCause();
final String failedTokenDocId = getTokenIdFromDocumentId(bulkItemResponse.getFailure().getId());
if (isShardNotAvailableException(cause)) {
retryTokenDocIds.add(failedTokenDocId);
}
else {
traceLog("invalidate access token", failedTokenDocId, cause);
failedRequestResponses.add(new ElasticsearchException("Error invalidating " + srcPrefix + ": ", cause));
}
} else {
UpdateResponse updateResponse = bulkItemResponse.getResponse();
if (updateResponse.getResult() == DocWriteResponse.Result.UPDATED) {
logger.debug("Invalidated [{}] for doc [{}]", srcPrefix, updateResponse.getGetResult().getId());
invalidated.add(updateResponse.getGetResult().getId());
} else if (updateResponse.getResult() == DocWriteResponse.Result.NOOP) {
previouslyInvalidated.add(updateResponse.getGetResult().getId());
}
}
}
if (retryTokenDocIds.isEmpty() == false) {
TokensInvalidationResult incompleteResult = new TokensInvalidationResult(invalidated, previouslyInvalidated,
failedRequestResponses, attemptCount.get());
attemptCount.incrementAndGet();
indexInvalidation(retryTokenDocIds, listener, attemptCount, srcPrefix, incompleteResult);
}
TokensInvalidationResult result = new TokensInvalidationResult(invalidated, previouslyInvalidated,
failedRequestResponses, attemptCount.get());
listener.onResponse(result);
}, e -> {
Throwable cause = ExceptionsHelper.unwrapCause(e);
traceLog("invalidate token", tokenDocId, cause);
if (cause instanceof DocumentMissingException) {
if (version.onOrAfter(Version.V_6_2_0)) {
// the document should always be there!
listener.onFailure(e);
} else {
listener.onResponse(false);
}
} else if (cause instanceof VersionConflictEngineException
|| isShardNotAvailableException(cause)) {
traceLog("invalidate tokens", cause);
if (isShardNotAvailableException(cause)) {
attemptCount.incrementAndGet();
executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN,
client.prepareGet(SecurityIndexManager.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(),
ActionListener.<GetResponse>wrap(getResult -> {
if (getResult.isExists()) {
Map<String, Object> source = getResult.getSource();
Map<String, Object> accessTokenSource = (Map<String, Object>) source.get("access_token");
Consumer<Exception> onFailure = ex -> listener.onFailure(traceLog("get token", tokenDocId, ex));
if (accessTokenSource == null) {
onFailure.accept(new IllegalArgumentException(
"token document is missing access_token field"));
} else {
Boolean invalidated = (Boolean) accessTokenSource.get("invalidated");
if (invalidated == null) {
onFailure.accept(new IllegalStateException(
"token document missing invalidated value"));
} else if (invalidated) {
logger.trace("Token [{}] is already invalidated", tokenDocId);
listener.onResponse(false);
} else {
indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix,
getResult.getVersion());
}
}
} else if (version.onOrAfter(Version.V_6_2_0)) {
logger.warn("could not find token document [{}] but there should " +
"be one as token has version [{}]", tokenDocId, version);
listener.onFailure(invalidGrantException("could not invalidate the token"));
} else {
listener.onResponse(false);
}
},
e1 -> {
traceLog("get token", tokenDocId, e1);
if (isShardNotAvailableException(e1)) {
// don't increment count; call again
indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix,
documentVersion);
} else {
listener.onFailure(e1);
}
}), client::get);
indexInvalidation(tokenIds, listener, attemptCount, srcPrefix, previousResult);
} else {
listener.onFailure(e);
}
}), client::update));
}), client::bulk));
}
}
@ -676,12 +761,12 @@ public final class TokenService {
public void refreshToken(String refreshToken, ActionListener<Tuple<UserToken, String>> listener) {
ensureEnabled();
findTokenFromRefreshToken(refreshToken,
ActionListener.wrap(tuple -> {
final Authentication userAuth = Authentication.readFromContext(client.threadPool().getThreadContext());
final String tokenDocId = tuple.v1().getHits().getHits()[0].getId();
innerRefresh(tokenDocId, userAuth, listener, tuple.v2());
}, listener::onFailure),
new AtomicInteger(0));
ActionListener.wrap(tuple -> {
final Authentication userAuth = Authentication.readFromContext(client.threadPool().getThreadContext());
final String tokenDocId = tuple.v1().getHits().getHits()[0].getId();
innerRefresh(tokenDocId, userAuth, listener, tuple.v2());
}, listener::onFailure),
new AtomicInteger(0));
}
private void findTokenFromRefreshToken(String refreshToken, ActionListener<Tuple<SearchResponse, AtomicInteger>> listener,
@ -691,11 +776,11 @@ public final class TokenService {
listener.onFailure(invalidGrantException("could not refresh the requested token"));
} else {
SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setQuery(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
.setVersion(true)
.request();
.setQuery(QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.termQuery("refresh_token.token", refreshToken)))
.setVersion(true)
.request();
final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
if (frozenSecurityIndex.indexExists() == false) {
@ -860,12 +945,16 @@ public final class TokenService {
}
/**
* Find all stored refresh and access tokens that have not been invalidated or expired, and were issued against
* Find stored refresh and access tokens that have not been invalidated or expired, and were issued against
* the specified realm.
*
* @param realmName The name of the realm for which to get the tokens
* @param listener The listener to notify upon completion
* @param filter an optional Predicate to test the source of the found documents against
*/
public void findActiveTokensForRealm(String realmName, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
public void findActiveTokensForRealm(String realmName, ActionListener<Collection<Tuple<UserToken, String>>> listener,
@Nullable Predicate<Map<String, Object>> filter) {
ensureEnabled();
final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
if (Strings.isNullOrEmpty(realmName)) {
listener.onFailure(new IllegalArgumentException("Realm name is required"));
@ -883,7 +972,10 @@ public final class TokenService {
.must(QueryBuilders.termQuery("access_token.invalidated", false))
.must(QueryBuilders.rangeQuery("access_token.user_token.expiration_time").gte(now.toEpochMilli()))
)
.should(QueryBuilders.termQuery("refresh_token.invalidated", false))
.should(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("refresh_token.invalidated", false))
.must(QueryBuilders.rangeQuery("creation_time").gte(now.toEpochMilli() - TimeValue.timeValueHours(24).millis()))
)
);
final SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
@ -893,33 +985,102 @@ public final class TokenService {
.setSize(1000)
.setFetchSource(true)
.request();
securityIndex.checkIndexVersionThenExecute(listener::onFailure,
() -> ScrollHelper.fetchAllByEntity(client, request, listener, this::parseHit));
() -> ScrollHelper.fetchAllByEntity(client, request, listener, (SearchHit hit) -> filterAndParseHit(hit, filter)));
}
}
private Tuple<UserToken, String> parseHit(SearchHit hit) {
/**
* Find stored refresh and access tokens that have not been invalidated or expired, and were issued for
* the specified user.
*
* @param username The user for which to get the tokens
* @param listener The listener to notify upon completion
*/
public void findActiveTokensForUser(String username, ActionListener<Collection<Tuple<UserToken, String>>> listener) {
ensureEnabled();
final SecurityIndexManager frozenSecurityIndex = securityIndex.freeze();
if (Strings.isNullOrEmpty(username)) {
listener.onFailure(new IllegalArgumentException("username is required"));
} else if (frozenSecurityIndex.indexExists() == false) {
listener.onResponse(Collections.emptyList());
} else if (frozenSecurityIndex.isAvailable() == false) {
listener.onFailure(frozenSecurityIndex.getUnavailableReason());
} else {
final Instant now = clock.instant();
final BoolQueryBuilder boolQuery = QueryBuilders.boolQuery()
.filter(QueryBuilders.termQuery("doc_type", "token"))
.filter(QueryBuilders.boolQuery()
.should(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("access_token.invalidated", false))
.must(QueryBuilders.rangeQuery("access_token.user_token.expiration_time").gte(now.toEpochMilli()))
)
.should(QueryBuilders.boolQuery()
.must(QueryBuilders.termQuery("refresh_token.invalidated", false))
.must(QueryBuilders.rangeQuery("creation_time").gte(now.toEpochMilli() - TimeValue.timeValueHours(24).millis()))
)
);
final SearchRequest request = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
.setScroll(DEFAULT_KEEPALIVE_SETTING.get(settings))
.setQuery(boolQuery)
.setVersion(false)
.setSize(1000)
.setFetchSource(true)
.request();
securityIndex.checkIndexVersionThenExecute(listener::onFailure,
() -> ScrollHelper.fetchAllByEntity(client, request, listener,
(SearchHit hit) -> filterAndParseHit(hit, isOfUser(username))));
}
}
private static Predicate<Map<String, Object>> isOfUser(String username) {
return source -> {
String auth = (String) source.get("authentication");
Integer version = (Integer) source.get("version");
Version authVersion = Version.fromId(version);
try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(auth))) {
in.setVersion(authVersion);
Authentication authentication = new Authentication(in);
return authentication.getUser().principal().equals(username);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
};
}
private Tuple<UserToken, String> filterAndParseHit(SearchHit hit, @Nullable Predicate<Map<String, Object>> filter) {
final Map<String, Object> source = hit.getSourceAsMap();
if (source == null) {
throw new IllegalStateException("token document did not have source but source should have been fetched");
}
try {
return parseTokensFromDocument(source);
return parseTokensFromDocument(source, filter);
} catch (IOException e) {
throw invalidGrantException("cannot read token from document");
}
}
/**
* @return A {@link Tuple} of access-token and refresh-token-id
*
* Parses a token document into a Tuple of a {@link UserToken} and a String representing the corresponding refresh_token
*
* @param source The token document source as retrieved
* @param filter an optional Predicate to test the source of the UserToken against
* @return A {@link Tuple} of access-token and refresh-token-id or null if a Predicate is defined and the userToken source doesn't
* satisfy it
*/
private Tuple<UserToken, String> parseTokensFromDocument(Map<String, Object> source) throws IOException {
final String refreshToken = (String) ((Map<String, Object>) source.get("refresh_token")).get("token");
private Tuple<UserToken, String> parseTokensFromDocument(Map<String, Object> source, @Nullable Predicate<Map<String, Object>> filter)
throws IOException {
final String refreshToken = (String) ((Map<String, Object>) source.get("refresh_token")).get("token");
final Map<String, Object> userTokenSource = (Map<String, Object>)
((Map<String, Object>) source.get("access_token")).get("user_token");
((Map<String, Object>) source.get("access_token")).get("user_token");
if (null != filter && filter.test(userTokenSource) == false) {
return null;
}
final String id = (String) userTokenSource.get("id");
final Integer version = (Integer) userTokenSource.get("version");
final String authString = (String) userTokenSource.get("authentication");
@ -951,6 +1112,23 @@ public final class TokenService {
return "token_" + id;
}
private static String getTokenIdFromDocumentId(String docId) {
if (docId.startsWith("token_") == false) {
throw new IllegalStateException("TokenDocument ID [" + docId + "] has unexpected value");
} else {
return docId.substring("token_".length());
}
}
private static String getTokenIdFromInvalidatedTokenDocumentId(String docId) {
final String invalidatedTokenDocPrefix = INVALIDATED_TOKEN_DOC_TYPE + "_";
if (docId.startsWith(invalidatedTokenDocPrefix) == false) {
throw new IllegalStateException("InvalidatedTokenDocument ID [" + docId + "] has unexpected value");
} else {
return docId.substring(invalidatedTokenDocPrefix.length());
}
}
private void ensureEnabled() {
if (enabled == false) {
throw new IllegalStateException("tokens are not enabled");
@ -1149,7 +1327,7 @@ public final class TokenService {
}
/**
* Creates an {@link ElasticsearchSecurityException} that indicates the token was expired. It
* Creates an {@link ElasticsearchSecurityException} that indicates the token was malformed. It
* is up to the client to re-authenticate and obtain a new token. The format for this response
* is defined in <a href="https://tools.ietf.org/html/rfc6750#section-3.1"></a>
*/
@ -1171,7 +1349,7 @@ public final class TokenService {
}
/**
* Logs an exception at TRACE level (if enabled)
* Logs an exception concerning a specific Token at TRACE level (if enabled)
*/
private <E extends Throwable> E traceLog(String action, String identifier, E exception) {
if (logger.isTraceEnabled()) {
@ -1179,12 +1357,34 @@ public final class TokenService {
final ElasticsearchException esEx = (ElasticsearchException) exception;
final Object detail = esEx.getHeader("error_description");
if (detail != null) {
logger.trace("Failure in [{}] for id [{}] - [{}] [{}]", action, identifier, detail, esEx.getDetailedMessage());
logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}] - [{}]", action, identifier, detail),
esEx);
} else {
logger.trace("Failure in [{}] for id [{}] - [{}]", action, identifier, esEx.getDetailedMessage());
logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}]", action, identifier),
esEx);
}
} else {
logger.trace("Failure in [{}] for id [{}] - [{}]", action, identifier, exception.toString());
logger.trace(() -> new ParameterizedMessage("Failure in [{}] for id [{}]", action, identifier), exception);
}
}
return exception;
}
/**
* Logs an exception at TRACE level (if enabled)
*/
private <E extends Throwable> E traceLog(String action, E exception) {
if (logger.isTraceEnabled()) {
if (exception instanceof ElasticsearchException) {
final ElasticsearchException esEx = (ElasticsearchException) exception;
final Object detail = esEx.getHeader("error_description");
if (detail != null) {
logger.trace(() -> new ParameterizedMessage("Failure in [{}] - [{}]", action, detail), esEx);
} else {
logger.trace(() -> new ParameterizedMessage("Failure in [{}]", action), esEx);
}
} else {
logger.trace(() -> new ParameterizedMessage("Failure in [{}]", action), exception);
}
}
return exception;

View File

@ -9,7 +9,6 @@ import org.apache.logging.log4j.LogManager;
import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.common.ParseField;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.logging.DeprecationLogger;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.ConstructingObjectParser;
@ -37,11 +36,32 @@ import static org.elasticsearch.rest.RestRequest.Method.DELETE;
public final class RestInvalidateTokenAction extends SecurityBaseRestHandler {
private static final DeprecationLogger deprecationLogger = new DeprecationLogger(LogManager.getLogger(RestInvalidateTokenAction.class));
static final ConstructingObjectParser<Tuple<String, String>, Void> PARSER =
new ConstructingObjectParser<>("invalidate_token", a -> new Tuple<>((String) a[0], (String) a[1]));
static final ConstructingObjectParser<InvalidateTokenRequest, Void> PARSER =
new ConstructingObjectParser<>("invalidate_token", a -> {
final String token = (String) a[0];
final String refreshToken = (String) a[1];
final String tokenString;
final String tokenType;
if (Strings.hasLength(token) && Strings.hasLength(refreshToken)) {
throw new IllegalArgumentException("only one of [token, refresh_token] may be sent per request");
} else if (Strings.hasLength(token)) {
tokenString = token;
tokenType = InvalidateTokenRequest.Type.ACCESS_TOKEN.getValue();
} else if (Strings.hasLength(refreshToken)) {
tokenString = refreshToken;
tokenType = InvalidateTokenRequest.Type.REFRESH_TOKEN.getValue();
} else {
tokenString = null;
tokenType = null;
}
return new InvalidateTokenRequest(tokenString, tokenType, (String) a[2], (String) a[3]);
});
static {
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("token"));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refresh_token"));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("realm_name"));
PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("username"));
}
public RestInvalidateTokenAction(Settings settings, RestController controller, XPackLicenseState xPackLicenseState) {
@ -60,36 +80,16 @@ public final class RestInvalidateTokenAction extends SecurityBaseRestHandler {
@Override
protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client) throws IOException {
try (XContentParser parser = request.contentParser()) {
final Tuple<String, String> tuple = PARSER.parse(parser, null);
final String token = tuple.v1();
final String refreshToken = tuple.v2();
final String tokenString;
final InvalidateTokenRequest.Type type;
if (Strings.hasLength(token) && Strings.hasLength(refreshToken)) {
throw new IllegalArgumentException("only one of [token, refresh_token] may be sent per request");
} else if (Strings.hasLength(token)) {
tokenString = token;
type = InvalidateTokenRequest.Type.ACCESS_TOKEN;
} else if (Strings.hasLength(refreshToken)) {
tokenString = refreshToken;
type = InvalidateTokenRequest.Type.REFRESH_TOKEN;
} else {
tokenString = null;
type = null;
}
final InvalidateTokenRequest tokenRequest = new InvalidateTokenRequest(tokenString, type);
return channel -> client.execute(InvalidateTokenAction.INSTANCE, tokenRequest,
new RestBuilderListener<InvalidateTokenResponse>(channel) {
@Override
public RestResponse buildResponse(InvalidateTokenResponse invalidateResp,
XContentBuilder builder) throws Exception {
return new BytesRestResponse(RestStatus.OK, builder.startObject()
.field("created", invalidateResp.isCreated())
.endObject());
}
});
final InvalidateTokenRequest invalidateTokenRequest = PARSER.parse(parser, null);
return channel -> client.execute(InvalidateTokenAction.INSTANCE, invalidateTokenRequest,
new RestBuilderListener<InvalidateTokenResponse>(channel) {
@Override
public RestResponse buildResponse(InvalidateTokenResponse invalidateResp,
XContentBuilder builder) throws Exception {
invalidateResp.toXContent(builder, channel.request());
return new BytesRestResponse(RestStatus.OK, builder);
}
});
}
}
}

View File

@ -11,6 +11,10 @@ import org.elasticsearch.action.Action;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.index.IndexAction;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.action.index.IndexResponse;
@ -21,11 +25,11 @@ import org.elasticsearch.action.search.SearchAction;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchResponseSections;
import org.elasticsearch.action.search.SearchScrollAction;
import org.elasticsearch.action.search.SearchScrollRequest;
import org.elasticsearch.action.support.ActionFilters;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.update.UpdateAction;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.bytes.BytesReference;
@ -106,11 +110,12 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
private SamlRealm samlRealm;
private TokenService tokenService;
private List<IndexRequest> indexRequests;
private List<UpdateRequest> updateRequests;
private List<BulkRequest> bulkRequests;
private List<SearchRequest> searchRequests;
private TransportSamlInvalidateSessionAction action;
private SamlLogoutRequestHandler.Result logoutRequest;
private Function<SearchRequest, SearchHit[]> searchFunction = ignore -> new SearchHit[0];
private Function<SearchScrollRequest, SearchHit[]> searchScrollFunction = ignore -> new SearchHit[0];
@Before
public void setup() throws Exception {
@ -132,8 +137,8 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
new Authentication(new User("kibana"), new RealmRef("realm", "type", "node"), null).writeToContext(threadContext);
indexRequests = new ArrayList<>();
updateRequests = new ArrayList<>();
searchRequests = new ArrayList<>();
bulkRequests = new ArrayList<>();
final Client client = new NoOpClient(threadPool) {
@Override
protected <Request extends ActionRequest, Response extends ActionResponse>
@ -143,20 +148,29 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
IndexRequest indexRequest = (IndexRequest) request;
indexRequests.add(indexRequest);
final IndexResponse response = new IndexResponse(
indexRequest.shardId(), indexRequest.type(), indexRequest.id(), 1, 1, 1, true);
indexRequest.shardId(), indexRequest.type(), indexRequest.id(), 1, 1, 1, true);
listener.onResponse((Response) response);
} else if (BulkAction.NAME.equals(action.name())) {
assertThat(request, instanceOf(BulkRequest.class));
bulkRequests.add((BulkRequest) request);
final BulkResponse response = new BulkResponse(new BulkItemResponse[0], 1);
listener.onResponse((Response) response);
} else if (UpdateAction.NAME.equals(action.name())) {
assertThat(request, instanceOf(UpdateRequest.class));
updateRequests.add((UpdateRequest) request);
listener.onResponse((Response) new UpdateResponse());
} else if (SearchAction.NAME.equals(action.name())) {
assertThat(request, instanceOf(SearchRequest.class));
SearchRequest searchRequest = (SearchRequest) request;
searchRequests.add(searchRequest);
final SearchHit[] hits = searchFunction.apply(searchRequest);
final SearchResponse response = new SearchResponse(
new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f),
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f),
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
listener.onResponse((Response) response);
} else if (SearchScrollAction.NAME.equals(action.name())){
assertThat(request, instanceOf(SearchScrollRequest.class));
SearchScrollRequest searchScrollRequest = (SearchScrollRequest) request;
final SearchHit[] hits = searchScrollFunction.apply(searchScrollRequest);
final SearchResponse response = new SearchResponse(
new SearchResponseSections(new SearchHits(hits, new TotalHits(hits.length, TotalHits.Relation.EQUAL_TO), 0f),
null, null, false, false, null, 1), "_scrollId1", 1, 1, 0, 1, null, null);
listener.onResponse((Response) response);
} else if (ClearScrollAction.NAME.equals(action.name())) {
assertThat(request, instanceOf(ClearScrollRequest.class));
@ -296,15 +310,33 @@ public class TransportSamlInvalidateSessionActionTests extends SamlTestCase {
assertThat(((TermQueryBuilder) filter1.get(1)).fieldName(), equalTo("refresh_token.token"));
assertThat(((TermQueryBuilder) filter1.get(1)).value(), equalTo(tokenToInvalidate1.v2()));
assertThat(updateRequests.size(), equalTo(4)); // (refresh-token + access-token) * 2
assertThat(updateRequests.get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
assertThat(updateRequests.get(1).id(), equalTo(updateRequests.get(0).id()));
assertThat(updateRequests.get(2).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
assertThat(updateRequests.get(3).id(), equalTo(updateRequests.get(2).id()));
assertThat(indexRequests.size(), equalTo(2)); // bwc-invalidate * 2
assertThat(indexRequests.get(0).id(), startsWith("invalidated-token_"));
assertThat(indexRequests.get(1).id(), startsWith("invalidated-token_"));
assertThat(bulkRequests.size(), equalTo(6)); // 4 updates (refresh-token + access-token) plus 2 indexes (bwc-invalidate * 2)
// Invalidate refresh token 1
assertThat(bulkRequests.get(0).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(0).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
UpdateRequest updateRequest1 = (UpdateRequest) bulkRequests.get(0).requests().get(0);
assertThat(updateRequest1.toString().contains("refresh_token"), equalTo(true));
// BWC incalidate access token 1
assertThat(bulkRequests.get(1).requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequests.get(1).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate1.v1().getId()));
// Invalidate access token 1
assertThat(bulkRequests.get(2).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(2).requests().get(0).id(), equalTo("token_" + tokenToInvalidate1.v1().getId()));
UpdateRequest updateRequest2 = (UpdateRequest) bulkRequests.get(2).requests().get(0);
assertThat(updateRequest2.toString().contains("access_token"), equalTo(true));
// Invalidate refresh token 2
assertThat(bulkRequests.get(3).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(3).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
UpdateRequest updateRequest3 = (UpdateRequest) bulkRequests.get(3).requests().get(0);
assertThat(updateRequest3.toString().contains("refresh_token"), equalTo(true));
// BWC incalidate access token 2
assertThat(bulkRequests.get(4).requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequests.get(4).requests().get(0).id(), equalTo("invalidated-token_" + tokenToInvalidate2.v1().getId()));
// Invalidate access token 2
assertThat(bulkRequests.get(5).requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequests.get(5).requests().get(0).id(), equalTo("token_" + tokenToInvalidate2.v1().getId()));
UpdateRequest updateRequest4 = (UpdateRequest) bulkRequests.get(5).requests().get(0);
assertThat(updateRequest4.toString().contains("access_token"), equalTo(true));
}
private Function<SearchRequest, SearchHit[]> findTokenByRefreshToken(SearchHit[] searchHits) {

View File

@ -6,7 +6,11 @@
package org.elasticsearch.xpack.security.action.saml;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.DocWriteResponse;
import org.elasticsearch.action.bulk.BulkAction;
import org.elasticsearch.action.bulk.BulkItemResponse;
import org.elasticsearch.action.bulk.BulkRequest;
import org.elasticsearch.action.bulk.BulkRequestBuilder;
import org.elasticsearch.action.bulk.BulkResponse;
import org.elasticsearch.action.get.GetAction;
import org.elasticsearch.action.get.GetRequestBuilder;
import org.elasticsearch.action.get.GetResponse;
@ -24,7 +28,6 @@ import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.action.update.UpdateAction;
import org.elasticsearch.action.update.UpdateRequest;
import org.elasticsearch.action.update.UpdateRequestBuilder;
import org.elasticsearch.action.update.UpdateResponse;
import org.elasticsearch.client.Client;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.collect.MapBuilder;
@ -72,6 +75,9 @@ import java.util.function.Consumer;
import static org.elasticsearch.xpack.core.security.authc.RealmSettings.getFullSettingKey;
import static org.elasticsearch.xpack.security.authc.TokenServiceTests.mockGetTokenFromId;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.notNullValue;
import static org.hamcrest.Matchers.startsWith;
import static org.mockito.Matchers.any;
@ -89,7 +95,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
private SamlRealm samlRealm;
private TokenService tokenService;
private List<IndexRequest> indexRequests;
private List<UpdateRequest> updateRequests;
private List<BulkRequest> bulkRequests;
private TransportSamlLogoutAction action;
private Client client;
@ -112,7 +118,7 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
new Authentication(new User("kibana"), new Authentication.RealmRef("realm", "type", "node"), null).writeToContext(threadContext);
indexRequests = new ArrayList<>();
updateRequests = new ArrayList<>();
bulkRequests = new ArrayList<>();
client = mock(Client.class);
when(client.threadPool()).thenReturn(threadPool);
when(client.settings()).thenReturn(settings);
@ -137,6 +143,10 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
.setId((String) invocationOnMock.getArguments()[2]);
return builder;
}).when(client).prepareUpdate(anyString(), anyString(), anyString());
doAnswer(invocationOnMock -> {
BulkRequestBuilder builder = new BulkRequestBuilder(client, BulkAction.INSTANCE);
return builder;
}).when(client).prepareBulk();
when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE));
doAnswer(invocationOnMock -> {
ActionListener<MultiGetResponse> listener = (ActionListener<MultiGetResponse>) invocationOnMock.getArguments()[1];
@ -154,15 +164,6 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
listener.onResponse(response);
return Void.TYPE;
}).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class));
doAnswer(invocationOnMock -> {
UpdateRequest updateRequest = (UpdateRequest) invocationOnMock.getArguments()[0];
ActionListener<UpdateResponse> listener = (ActionListener<UpdateResponse>) invocationOnMock.getArguments()[1];
updateRequests.add(updateRequest);
final UpdateResponse response = new UpdateResponse(
updateRequest.getShardId(), updateRequest.type(), updateRequest.id(), 1, DocWriteResponse.Result.UPDATED);
listener.onResponse(response);
return Void.TYPE;
}).when(client).update(any(UpdateRequest.class), any(ActionListener.class));
doAnswer(invocationOnMock -> {
IndexRequest indexRequest = (IndexRequest) invocationOnMock.getArguments()[0];
ActionListener<IndexResponse> listener = (ActionListener<IndexResponse>) invocationOnMock.getArguments()[1];
@ -181,6 +182,14 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
listener.onResponse(response);
return Void.TYPE;
}).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class));
doAnswer(invocationOnMock -> {
BulkRequest bulkRequest = (BulkRequest) invocationOnMock.getArguments()[0];
ActionListener<BulkResponse> listener = (ActionListener<BulkResponse>) invocationOnMock.getArguments()[1];
bulkRequests.add(bulkRequest);
final BulkResponse response = new BulkResponse(new BulkItemResponse[0], 1);
listener.onResponse(response);
return Void.TYPE;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));
final SecurityIndexManager securityIndex = mock(SecurityIndexManager.class);
doAnswer(inv -> {
@ -247,9 +256,17 @@ public class TransportSamlLogoutActionTests extends SamlTestCase {
assertThat(indexRequest1, notNullValue());
assertThat(indexRequest1.id(), startsWith("token"));
final IndexRequest indexRequest2 = indexRequests.get(1);
assertThat(indexRequest2, notNullValue());
assertThat(indexRequest2.id(), startsWith("invalidated-token"));
assertThat(bulkRequests.size(), equalTo(2));
final BulkRequest bulkRequest1 = bulkRequests.get(0);
assertThat(bulkRequest1.requests().size(), equalTo(1));
assertThat(bulkRequest1.requests().get(0), instanceOf(IndexRequest.class));
assertThat(bulkRequest1.requests().get(0).id(), startsWith("invalidated-token_"));
final BulkRequest bulkRequest2 = bulkRequests.get(1);
assertThat(bulkRequest2.requests().size(), equalTo(1));
assertThat(bulkRequest2.requests().get(0), instanceOf(UpdateRequest.class));
assertThat(bulkRequest2.requests().get(0).id(), startsWith("token_"));
assertThat(bulkRequest2.requests().get(0).toString(), containsString("\"access_token\":{\"invalidated\":true"));
}
}

View File

@ -144,7 +144,9 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
.prepareInvalidateToken(response.getTokenString())
.setType(InvalidateTokenRequest.Type.ACCESS_TOKEN)
.get();
assertTrue(invalidateResponse.isCreated());
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
AtomicReference<String> docId = new AtomicReference<>();
assertBusy(() -> {
SearchResponse searchResponse = client.prepareSearch(SecurityIndexManager.SECURITY_INDEX_NAME)
@ -189,6 +191,72 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
}, 30, TimeUnit.SECONDS);
}
public void testInvalidateAllTokensForUser() throws Exception{
final int numOfRequests = randomIntBetween(5, 10);
for (int i = 0; i < numOfRequests; i++) {
securityClient().prepareCreateToken()
.setGrantType("password")
.setUsername(SecuritySettingsSource.TEST_USER_NAME)
.setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()))
.get();
}
Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient securityClientSuperuser = new SecurityClient(client);
InvalidateTokenResponse invalidateResponse = securityClientSuperuser
.prepareInvalidateToken()
.setUserName(SecuritySettingsSource.TEST_USER_NAME)
.get();
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(2 * (numOfRequests)));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
}
public void testInvalidateAllTokensForRealm() throws Exception{
final int numOfRequests = randomIntBetween(5, 10);
for (int i = 0; i < numOfRequests; i++) {
securityClient().prepareCreateToken()
.setGrantType("password")
.setUsername(SecuritySettingsSource.TEST_USER_NAME)
.setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()))
.get();
}
Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient securityClientSuperuser = new SecurityClient(client);
InvalidateTokenResponse invalidateResponse = securityClientSuperuser
.prepareInvalidateToken()
.setRealmName("file")
.get();
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(2 * (numOfRequests)));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
}
public void testInvalidateAllTokensForRealmThatHasNone() {
final int numOfRequests = randomIntBetween(2, 4);
for (int i = 0; i < numOfRequests; i++) {
securityClient().prepareCreateToken()
.setGrantType("password")
.setUsername(SecuritySettingsSource.TEST_USER_NAME)
.setPassword(new SecureString(SecuritySettingsSourceField.TEST_PASSWORD.toCharArray()))
.get();
}
Client client = client().filterWithHeader(Collections.singletonMap("Authorization",
UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER,
SecuritySettingsSourceField.TEST_PASSWORD_SECURE_STRING)));
SecurityClient securityClientSuperuser = new SecurityClient(client);
InvalidateTokenResponse invalidateResponse = securityClientSuperuser
.prepareInvalidateToken()
.setRealmName("saml")
.get();
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
}
public void testExpireMultipleTimes() {
CreateTokenResponse response = securityClient().prepareCreateToken()
.setGrantType("password")
@ -200,12 +268,16 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
.prepareInvalidateToken(response.getTokenString())
.setType(InvalidateTokenRequest.Type.ACCESS_TOKEN)
.get();
assertTrue(invalidateResponse.isCreated());
assertFalse(securityClient()
.prepareInvalidateToken(response.getTokenString())
.setType(InvalidateTokenRequest.Type.ACCESS_TOKEN)
.get()
.isCreated());
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
InvalidateTokenResponse invalidateAgainResponse = securityClient()
.prepareInvalidateToken(response.getTokenString())
.setType(InvalidateTokenRequest.Type.ACCESS_TOKEN)
.get();
assertThat(invalidateAgainResponse.getResult().getInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateAgainResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(1));
assertThat(invalidateAgainResponse.getResult().getErrors().size(), equalTo(0));
}
public void testRefreshingToken() {
@ -248,7 +320,9 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
.prepareInvalidateToken(createTokenResponse.getRefreshToken())
.setType(InvalidateTokenRequest.Type.REFRESH_TOKEN)
.get();
assertTrue(invalidateResponse.isCreated());
assertThat(invalidateResponse.getResult().getInvalidatedTokens().size(), equalTo(1));
assertThat(invalidateResponse.getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponse.getResult().getErrors().size(), equalTo(0));
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class,
() -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get());
@ -362,9 +436,11 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase {
// invalidate
PlainActionFuture<InvalidateTokenResponse> invalidateResponseFuture = new PlainActionFuture<>();
InvalidateTokenRequest invalidateTokenRequest =
new InvalidateTokenRequest(createTokenResponse.getTokenString(), InvalidateTokenRequest.Type.ACCESS_TOKEN);
new InvalidateTokenRequest(createTokenResponse.getTokenString(), InvalidateTokenRequest.Type.ACCESS_TOKEN.getValue());
securityClient.invalidateToken(invalidateTokenRequest, invalidateResponseFuture);
assertTrue(invalidateResponseFuture.get().isCreated());
assertThat(invalidateResponseFuture.get().getResult().getInvalidatedTokens().size(), equalTo(1));
assertThat(invalidateResponseFuture.get().getResult().getPreviouslyInvalidatedTokens().size(), equalTo(0));
assertThat(invalidateResponseFuture.get().getResult().getErrors().size(), equalTo(0));
ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, () -> {
PlainActionFuture<AuthenticateResponse> responseFuture = new PlainActionFuture<>();

View File

@ -48,6 +48,7 @@ import org.elasticsearch.xpack.core.XPackSettings;
import org.elasticsearch.xpack.core.security.authc.Authentication;
import org.elasticsearch.xpack.core.security.authc.Authentication.RealmRef;
import org.elasticsearch.xpack.core.security.authc.TokenMetaData;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import org.elasticsearch.xpack.core.security.user.User;
import org.elasticsearch.xpack.core.watcher.watch.ClockMock;
import org.elasticsearch.xpack.security.support.SecurityIndexManager;
@ -523,7 +524,7 @@ public class TokenServiceTests extends ESTestCase {
assertNull(future.get());
e = expectThrows(IllegalStateException.class, () -> {
PlainActionFuture<Boolean> invalidateFuture = new PlainActionFuture<>();
PlainActionFuture<TokensInvalidationResult> invalidateFuture = new PlainActionFuture<>();
tokenService.invalidateAccessToken((String) null, invalidateFuture);
invalidateFuture.actionGet();
});

View File

@ -0,0 +1,74 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.authc.support;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.common.xcontent.json.JsonXContent;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.security.authc.support.TokensInvalidationResult;
import java.util.Arrays;
import java.util.Collections;
import static org.hamcrest.Matchers.equalTo;
public class TokensInvalidationResultTests extends ESTestCase {
public void testToXcontent() throws Exception{
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"),
Arrays.asList("token3", "token4"),
Arrays.asList(new ElasticsearchException("foo", new IllegalStateException("bar")),
new ElasticsearchException("boo", new IllegalStateException("far"))),
randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo(
"{\"created\":false," +
"\"invalidated_tokens\":2," +
"\"previously_invalidated_tokens\":2," +
"\"error_count\":2," +
"\"error_details\":[" +
"{\"type\":\"exception\"," +
"\"reason\":\"foo\"," +
"\"caused_by\":{" +
"\"type\":\"illegal_state_exception\"," +
"\"reason\":\"bar\"" +
"}" +
"}," +
"{\"type\":\"exception\"," +
"\"reason\":\"boo\"," +
"\"caused_by\":{" +
"\"type\":\"illegal_state_exception\"," +
"\"reason\":\"far\"" +
"}" +
"}" +
"]" +
"}"));
}
}
public void testToXcontentWithNoErrors() throws Exception{
TokensInvalidationResult result = new TokensInvalidationResult(Arrays.asList("token1", "token2"),
Collections.emptyList(),
Collections.emptyList(), randomIntBetween(0, 5));
try (XContentBuilder builder = JsonXContent.contentBuilder()) {
result.toXContent(builder, ToXContent.EMPTY_PARAMS);
assertThat(Strings.toString(builder),
equalTo(
"{\"created\":true," +
"\"invalidated_tokens\":2," +
"\"previously_invalidated_tokens\":0," +
"\"error_count\":0" +
"}"));
}
}
}

View File

@ -0,0 +1,61 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/
package org.elasticsearch.xpack.security.rest.action.oauth2;
import org.elasticsearch.common.xcontent.DeprecationHandler;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
import org.elasticsearch.common.xcontent.XContentParser;
import org.elasticsearch.common.xcontent.XContentType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.core.security.action.token.InvalidateTokenRequest;
import static org.hamcrest.Matchers.containsString;
public class RestInvalidateTokenActionTests extends ESTestCase {
public void testParserForUserAndRealm() throws Exception {
final String request = "{" +
"\"username\": \"user1\"," +
"\"realm_name\": \"realm1\"" +
"}";
try (XContentParser parser = XContentType.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) {
InvalidateTokenRequest invalidateTokenRequest = RestInvalidateTokenAction.PARSER.parse(parser, null);
assertEquals("user1", invalidateTokenRequest.getUserName());
assertEquals("realm1", invalidateTokenRequest.getRealmName());
assertNull(invalidateTokenRequest.getTokenString());
assertNull(invalidateTokenRequest.getTokenType());
}
}
public void testParserForToken() throws Exception {
final String request = "{" +
"\"refresh_token\": \"refresh_token_string\"" +
"}";
try (XContentParser parser = XContentType.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) {
InvalidateTokenRequest invalidateTokenRequest = RestInvalidateTokenAction.PARSER.parse(parser, null);
assertEquals("refresh_token_string", invalidateTokenRequest.getTokenString());
assertEquals("refresh_token", invalidateTokenRequest.getTokenType().getValue());
assertNull(invalidateTokenRequest.getRealmName());
assertNull(invalidateTokenRequest.getUserName());
}
}
public void testParserForIncorrectInput() throws Exception {
final String request = "{" +
"\"refresh_token\": \"refresh_token_string\"," +
"\"token\": \"access_token_string\"" +
"}";
try (XContentParser parser = XContentType.JSON.xContent()
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, request)) {
IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> RestInvalidateTokenAction.PARSER.parse(parser,
null));
assertThat(e.getCause().getMessage(), containsString("only one of [token, refresh_token] may be sent per request"));
}
}
}

View File

@ -5,7 +5,7 @@ setup:
- do:
cluster.health:
wait_for_status: yellow
wait_for_status: yellow
- do:
security.put_user:
@ -79,7 +79,93 @@ teardown:
body:
token: $token
- match: { created: true }
- match: { created: true}
- match: { invalidated_tokens: 1 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }
- do:
catch: unauthorized
headers:
Authorization: Bearer ${token}
security.authenticate: {}
---
"Test invalidate user's tokens":
- do:
security.get_token:
body:
grant_type: "password"
username: "token_user"
password: "x-pack-test-password"
- match: { type: "Bearer" }
- is_true: access_token
- set: { access_token: token }
- match: { expires_in: 1200 }
- is_false: scope
- do:
headers:
Authorization: Bearer ${token}
security.authenticate: {}
- match: { username: "token_user" }
- match: { roles.0: "superuser" }
- match: { full_name: "Token User" }
- do:
security.invalidate_token:
body:
username: "token_user"
- match: { created: true}
- match: { invalidated_tokens: 2 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }
- do:
catch: unauthorized
headers:
Authorization: Bearer ${token}
security.authenticate: {}
---
"Test invalidate realm's tokens":
- do:
security.get_token:
body:
grant_type: "password"
username: "token_user"
password: "x-pack-test-password"
- match: { type: "Bearer" }
- is_true: access_token
- set: { access_token: token }
- match: { expires_in: 1200 }
- is_false: scope
- do:
headers:
Authorization: Bearer ${token}
security.authenticate: {}
- match: { username: "token_user" }
- match: { roles.0: "superuser" }
- match: { full_name: "Token User" }
- do:
security.invalidate_token:
body:
realm_name: "default_native"
- match: { created: true}
- match: { invalidated_tokens: 2 }
- match: { previously_invalidated_tokens: 0 }
- match: { error_count: 0 }
- do:
catch: unauthorized