diff --git a/docs/en/rest-api/security/tokens.asciidoc b/docs/en/rest-api/security/tokens.asciidoc index 571cc3fc623..cb1b39701a6 100644 --- a/docs/en/rest-api/security/tokens.asciidoc +++ b/docs/en/rest-api/security/tokens.asciidoc @@ -74,10 +74,12 @@ seconds) that the token expires in, and the type: { "access_token" : "dGhpcyBpcyBub3QgYSByZWFsIHRva2VuIGJ1dCBpdCBpcyBvbmx5IHRlc3QgZGF0YS4gZG8gbm90IHRyeSB0byByZWFkIHRva2VuIQ==", "type" : "Bearer", - "expires_in" : 1200 + "expires_in" : 1200, + "refresh_token": "vLBPvmAB6KvwvJZr27cS" } -------------------------------------------------- // TESTRESPONSE[s/dGhpcyBpcyBub3QgYSByZWFsIHRva2VuIGJ1dCBpdCBpcyBvbmx5IHRlc3QgZGF0YS4gZG8gbm90IHRyeSB0byByZWFkIHRva2VuIQ==/$body.access_token/] +// TESTRESPONSE[s/vLBPvmAB6KvwvJZr27cS/$body.refresh_token/] The token returned by this API can be used by sending a request with a `Authorization` header with a value having the prefix `Bearer ` followed @@ -88,6 +90,36 @@ by the value of the `access_token`. curl -H "Authorization: Bearer dGhpcyBpcyBub3QgYSByZWFsIHRva2VuIGJ1dCBpdCBpcyBvbmx5IHRlc3QgZGF0YS4gZG8gbm90IHRyeSB0byByZWFkIHRva2VuIQ==" http://localhost:9200/_cluster/health -------------------------------------------------- +[[security-api-refresh-token]] +To extend the life of an existing token, the token api may be called again with the refresh +token within 24 hours of the token's creation. + +[source,js] +-------------------------------------------------- +POST /_xpack/security/oauth2/token +{ + "grant_type": "refresh_token", + "refresh_token": "vLBPvmAB6KvwvJZr27cS" +} +-------------------------------------------------- +// CONSOLE +// TEST[s/vLBPvmAB6KvwvJZr27cS/$body.refresh_token/] +// TEST[continued] + +The API will return a new token and refresh token. Each refresh token may only be used one time. + +[source,js] +-------------------------------------------------- +{ + "access_token" : "dGhpcyBpcyBub3QgYSByZWFsIHRva2VuIGJ1dCBpdCBpcyBvbmx5IHRlc3QgZGF0YS4gZG8gbm90IHRyeSB0byByZWFkIHRva2VuIQ==", + "type" : "Bearer", + "expires_in" : 1200, + "refresh_token": "vLBPvmAB6KvwvJZr27cS" +} +-------------------------------------------------- +// TESTRESPONSE[s/dGhpcyBpcyBub3QgYSByZWFsIHRva2VuIGJ1dCBpdCBpcyBvbmx5IHRlc3QgZGF0YS4gZG8gbm90IHRyeSB0byByZWFkIHRva2VuIQ==/$body.access_token/] +// TESTRESPONSE[s/vLBPvmAB6KvwvJZr27cS/$body.refresh_token/] + [[security-api-invalidate-token]] If a token must be invalidated immediately, you can do so by submitting a DELETE request to `/_xpack/security/oauth2/token`. For example: diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenAction.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenAction.java index a99c98aed78..9eb72a09077 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenAction.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenAction.java @@ -22,7 +22,7 @@ public final class CreateTokenAction extends Actionpassword grant type. + * fields for an OAuth 2.0 access token request that uses the password grant type or the + * refresh_token grant type. */ public final class CreateTokenRequest extends ActionRequest { @@ -29,27 +33,47 @@ public final class CreateTokenRequest extends ActionRequest { private String username; private SecureString password; private String scope; + private String refreshToken; CreateTokenRequest() {} - public CreateTokenRequest(String grantType, String username, SecureString password, @Nullable String scope) { + public CreateTokenRequest(String grantType, @Nullable String username, @Nullable SecureString password, @Nullable String scope, + @Nullable String refreshToken) { this.grantType = grantType; this.username = username; this.password = password; this.scope = scope; + this.refreshToken = refreshToken; } @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if ("password".equals(grantType) == false) { - validationException = addValidationError("only [password] grant_type is supported", validationException); - } - if (Strings.isNullOrEmpty(username)) { - validationException = addValidationError("username is missing", validationException); - } - if (password == null || password.getChars() == null || password.getChars().length == 0) { - validationException = addValidationError("password is missing", validationException); + if ("password".equals(grantType)) { + if (Strings.isNullOrEmpty(username)) { + validationException = addValidationError("username is missing", validationException); + } + if (password == null || password.getChars() == null || password.getChars().length == 0) { + validationException = addValidationError("password is missing", validationException); + } + if (refreshToken != null) { + validationException = + addValidationError("refresh_token is not supported with the password grant_type", validationException); + } + } else if ("refresh_token".equals(grantType)) { + if (username != null) { + validationException = + addValidationError("username is not supported with the refresh_token grant_type", validationException); + } + if (password != null) { + validationException = + addValidationError("password is not supported with the refresh_token grant_type", validationException); + } + if (refreshToken == null) { + validationException = addValidationError("refresh_token is missing", validationException); + } + } else { + validationException = addValidationError("grant_type only supports the values: [password, refresh_token]", validationException); } return validationException; @@ -59,11 +83,11 @@ public final class CreateTokenRequest extends ActionRequest { this.grantType = grantType; } - public void setUsername(String username) { + public void setUsername(@Nullable String username) { this.username = username; } - public void setPassword(SecureString password) { + public void setPassword(@Nullable SecureString password) { this.password = password; } @@ -71,14 +95,20 @@ public final class CreateTokenRequest extends ActionRequest { this.scope = scope; } + public void setRefreshToken(@Nullable String refreshToken) { + this.refreshToken = refreshToken; + } + public String getGrantType() { return grantType; } + @Nullable public String getUsername() { return username; } + @Nullable public SecureString getPassword() { return password; } @@ -88,16 +118,40 @@ public final class CreateTokenRequest extends ActionRequest { return scope; } + @Nullable + public String getRefreshToken() { + return refreshToken; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(grantType); - out.writeString(username); - final byte[] passwordBytes = CharArrays.toUtf8Bytes(password.getChars()); - try { - out.writeByteArray(passwordBytes); - } finally { - Arrays.fill(passwordBytes, (byte) 0); + if (out.getVersion().onOrAfter(Version.V_6_2_0)) { + out.writeOptionalString(username); + if (password == null) { + out.writeOptionalBytesReference(null); + } else { + final byte[] passwordBytes = CharArrays.toUtf8Bytes(password.getChars()); + try { + out.writeOptionalBytesReference(new BytesArray(passwordBytes)); + } finally { + Arrays.fill(passwordBytes, (byte) 0); + } + } + out.writeOptionalString(refreshToken); + } else { + if ("refresh_token".equals(grantType)) { + throw new UnsupportedOperationException("a refresh request cannot be sent to an older version"); + } else { + out.writeString(username); + final byte[] passwordBytes = CharArrays.toUtf8Bytes(password.getChars()); + try { + out.writeByteArray(passwordBytes); + } finally { + Arrays.fill(passwordBytes, (byte) 0); + } + } } out.writeOptionalString(scope); } @@ -106,12 +160,28 @@ public final class CreateTokenRequest extends ActionRequest { public void readFrom(StreamInput in) throws IOException { super.readFrom(in); grantType = in.readString(); - username = in.readString(); - final byte[] passwordBytes = in.readByteArray(); - try { - password = new SecureString(CharArrays.utf8BytesToChars(passwordBytes)); - } finally { - Arrays.fill(passwordBytes, (byte) 0); + if (in.getVersion().onOrAfter(Version.V_6_2_0)) { + username = in.readOptionalString(); + BytesReference bytesRef = in.readOptionalBytesReference(); + if (bytesRef != null) { + byte[] bytes = BytesReference.toBytes(bytesRef); + try { + password = new SecureString(CharArrays.utf8BytesToChars(bytes)); + } finally { + Arrays.fill(bytes, (byte) 0); + } + } else { + password = null; + } + refreshToken = in.readOptionalString(); + } else { + username = in.readString(); + final byte[] passwordBytes = in.readByteArray(); + try { + password = new SecureString(CharArrays.utf8BytesToChars(passwordBytes)); + } finally { + Arrays.fill(passwordBytes, (byte) 0); + } } scope = in.readOptionalString(); } diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestBuilder.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestBuilder.java index 74304470ff2..ac7bdf9d8e7 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestBuilder.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestBuilder.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.action.token; +import org.elasticsearch.action.Action; import org.elasticsearch.action.ActionRequestBuilder; import org.elasticsearch.client.ElasticsearchClient; import org.elasticsearch.common.Nullable; @@ -16,8 +17,9 @@ import org.elasticsearch.common.settings.SecureString; public final class CreateTokenRequestBuilder extends ActionRequestBuilder { - public CreateTokenRequestBuilder(ElasticsearchClient client) { - super(client, CreateTokenAction.INSTANCE, new CreateTokenRequest()); + public CreateTokenRequestBuilder(ElasticsearchClient client, + Action action) { + super(client, action, new CreateTokenRequest()); } /** @@ -31,7 +33,7 @@ public final class CreateTokenRequestBuilder /** * Set the username to be used for authentication with a password grant */ - public CreateTokenRequestBuilder setUsername(String username) { + public CreateTokenRequestBuilder setUsername(@Nullable String username) { request.setUsername(username); return this; } @@ -40,7 +42,7 @@ public final class CreateTokenRequestBuilder * Set the password credentials associated with the user. These credentials will be used for * authentication and the resulting token will be for this user */ - public CreateTokenRequestBuilder setPassword(SecureString password) { + public CreateTokenRequestBuilder setPassword(@Nullable SecureString password) { request.setPassword(password); return this; } @@ -54,4 +56,9 @@ public final class CreateTokenRequestBuilder request.setScope(scope); return this; } + + public CreateTokenRequestBuilder setRefreshToken(@Nullable String refreshToken) { + request.setRefreshToken(refreshToken); + return this; + } } diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenResponse.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenResponse.java index a7d9174d462..0fe6d5f729c 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenResponse.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/CreateTokenResponse.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.action.token; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -25,13 +26,15 @@ public final class CreateTokenResponse extends ActionResponse implements ToXCont private String tokenString; private TimeValue expiresIn; private String scope; + private String refreshToken; CreateTokenResponse() {} - public CreateTokenResponse(String tokenString, TimeValue expiresIn, String scope) { + public CreateTokenResponse(String tokenString, TimeValue expiresIn, String scope, String refreshToken) { this.tokenString = Objects.requireNonNull(tokenString); this.expiresIn = Objects.requireNonNull(expiresIn); this.scope = scope; + this.refreshToken = refreshToken; } public String getTokenString() { @@ -46,12 +49,19 @@ public final class CreateTokenResponse extends ActionResponse implements ToXCont return expiresIn; } + public String getRefreshToken() { + return refreshToken; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(tokenString); expiresIn.writeTo(out); out.writeOptionalString(scope); + if (out.getVersion().onOrAfter(Version.V_6_2_0)) { + out.writeString(refreshToken); + } } @Override @@ -60,6 +70,9 @@ public final class CreateTokenResponse extends ActionResponse implements ToXCont tokenString = in.readString(); expiresIn = new TimeValue(in); scope = in.readOptionalString(); + if (in.getVersion().onOrAfter(Version.V_6_2_0)) { + refreshToken = in.readString(); + } } @Override @@ -68,6 +81,9 @@ public final class CreateTokenResponse extends ActionResponse implements ToXCont .field("access_token", tokenString) .field("type", "Bearer") .field("expires_in", expiresIn.seconds()); + if (refreshToken != null) { + builder.field("refresh_token", refreshToken); + } // only show the scope if it is not null if (scope != null) { builder.field("scope", scope); diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequest.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequest.java index 9fe57757b7a..0e37259247e 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequest.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequest.java @@ -5,6 +5,7 @@ */ package org.elasticsearch.xpack.security.action.token; +import org.elasticsearch.Version; import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.Strings; @@ -20,15 +21,22 @@ import static org.elasticsearch.action.ValidateActions.addValidationError; */ public final class InvalidateTokenRequest extends ActionRequest { + public enum Type { + ACCESS_TOKEN, + REFRESH_TOKEN + } + private String tokenString; + private Type tokenType; InvalidateTokenRequest() {} /** * @param tokenString the string representation of the token */ - public InvalidateTokenRequest(String tokenString) { + public InvalidateTokenRequest(String tokenString, Type type) { this.tokenString = tokenString; + this.tokenType = type; } @Override @@ -37,6 +45,9 @@ public final class InvalidateTokenRequest extends ActionRequest { if (Strings.isNullOrEmpty(tokenString)) { validationException = addValidationError("token string must be provided", null); } + if (tokenType == null) { + validationException = addValidationError("token type must be provided", validationException); + } return validationException; } @@ -48,15 +59,34 @@ public final class InvalidateTokenRequest extends ActionRequest { this.tokenString = token; } + Type getTokenType() { + return tokenType; + } + + void setTokenType(Type tokenType) { + this.tokenType = tokenType; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(tokenString); + if (out.getVersion().onOrAfter(Version.V_6_2_0)) { + out.writeVInt(tokenType.ordinal()); + } else if (tokenType == Type.REFRESH_TOKEN) { + throw new UnsupportedOperationException("refresh token invalidation cannot be serialized with version [" + out.getVersion() + + "]"); + } } @Override public void readFrom(StreamInput in) throws IOException { super.readFrom(in); tokenString = in.readString(); + if (in.getVersion().onOrAfter(Version.V_6_2_0)) { + tokenType = Type.values()[in.readVInt()]; + } else { + tokenType = Type.ACCESS_TOKEN; + } } } diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequestBuilder.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequestBuilder.java index ead90673d17..71e2110b892 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequestBuilder.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/InvalidateTokenRequestBuilder.java @@ -26,4 +26,12 @@ public final class InvalidateTokenRequestBuilder request.setTokenString(token); return this; } + + /** + * Sets the type of the token that should be invalidated + */ + public InvalidateTokenRequestBuilder setType(InvalidateTokenRequest.Type type) { + request.setTokenType(type); + return this; + } } diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/RefreshTokenAction.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/RefreshTokenAction.java new file mode 100644 index 00000000000..c6384950bca --- /dev/null +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/token/RefreshTokenAction.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security.action.token; + +import org.elasticsearch.action.Action; +import org.elasticsearch.client.ElasticsearchClient; + +public final class RefreshTokenAction extends Action { + + public static final String NAME = "cluster:admin/xpack/security/token/refresh"; + public static final RefreshTokenAction INSTANCE = new RefreshTokenAction(); + + private RefreshTokenAction() { + super(NAME); + } + + @Override + public CreateTokenRequestBuilder newRequestBuilder(ElasticsearchClient client) { + return new CreateTokenRequestBuilder(client, INSTANCE); + } + + @Override + public CreateTokenResponse newResponse() { + return new CreateTokenResponse(); + } +} diff --git a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/user/AuthenticateRequest.java b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/user/AuthenticateRequest.java index ea92201b244..1b65cd1282d 100644 --- a/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/user/AuthenticateRequest.java +++ b/plugin/core/src/main/java/org/elasticsearch/xpack/security/action/user/AuthenticateRequest.java @@ -9,12 +9,9 @@ import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.xpack.security.support.Validation; import java.io.IOException; -import static org.elasticsearch.action.ValidateActions.addValidationError; - public class AuthenticateRequest extends ActionRequest implements UserRequest { private String username; diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/Security.java b/plugin/src/main/java/org/elasticsearch/xpack/security/Security.java index e24b0c7955a..b14c7e7442b 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/Security.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/Security.java @@ -106,8 +106,10 @@ import org.elasticsearch.xpack.security.action.rolemapping.TransportGetRoleMappi import org.elasticsearch.xpack.security.action.rolemapping.TransportPutRoleMappingAction; import org.elasticsearch.xpack.security.action.token.CreateTokenAction; import org.elasticsearch.xpack.security.action.token.InvalidateTokenAction; +import org.elasticsearch.xpack.security.action.token.RefreshTokenAction; import org.elasticsearch.xpack.security.action.token.TransportCreateTokenAction; import org.elasticsearch.xpack.security.action.token.TransportInvalidateTokenAction; +import org.elasticsearch.xpack.security.action.token.TransportRefreshTokenAction; import org.elasticsearch.xpack.security.action.user.AuthenticateAction; import org.elasticsearch.xpack.security.action.user.ChangePasswordAction; import org.elasticsearch.xpack.security.action.user.DeleteUserAction; @@ -609,7 +611,8 @@ public class Security implements ActionPlugin, IngestPlugin, NetworkPlugin, Clus new ActionHandler<>(DeleteRoleMappingAction.INSTANCE, TransportDeleteRoleMappingAction.class), new ActionHandler<>(CreateTokenAction.INSTANCE, TransportCreateTokenAction.class), new ActionHandler<>(InvalidateTokenAction.INSTANCE, TransportInvalidateTokenAction.class), - new ActionHandler<>(GetCertificateInfoAction.INSTANCE, TransportGetCertificateInfoAction.class) + new ActionHandler<>(GetCertificateInfoAction.INSTANCE, TransportGetCertificateInfoAction.class), + new ActionHandler<>(RefreshTokenAction.INSTANCE, TransportRefreshTokenAction.class) ); } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java b/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java index 7ebc5696bfa..ee417c56064 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportCreateTokenAction.java @@ -10,14 +10,15 @@ import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; import org.elasticsearch.common.inject.Inject; -import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.security.authc.Authentication; import org.elasticsearch.xpack.security.authc.AuthenticationService; import org.elasticsearch.xpack.security.authc.TokenService; -import org.elasticsearch.xpack.security.authc.UserToken; + +import java.util.Collections; /** * Transport action responsible for creating a token based on a request. Requests provide user @@ -43,27 +44,24 @@ public final class TransportCreateTokenAction extends HandledTransportAction listener) { + Authentication originatingAuthentication = Authentication.getAuthentication(threadPool.getThreadContext()); try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { authenticationService.authenticate(CreateTokenAction.NAME, request, request.getUsername(), request.getPassword(), ActionListener.wrap(authentication -> { - try (SecureString ignore1 = request.getPassword()) { - final UserToken token = tokenService.createUserToken(authentication); - final String tokenStr = tokenService.getUserTokenString(token); - final String scope; - // the OAuth2.0 RFC requires the scope to be provided in the - // response if it differs from the user provided scope. If the - // scope was not provided then it does not need to be returned. - // if the scope is not supported, the value of the scope that the - // token is for must be returned - if (request.getScope() != null) { - scope = DEFAULT_SCOPE; // this is the only non-null value that is currently supported - } else { - scope = null; - } + request.getPassword().close(); + tokenService.createUserToken(authentication, originatingAuthentication, ActionListener.wrap(tuple -> { + final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String scope = getResponseScopeValue(request.getScope()); - listener.onResponse(new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope)); - } + final CreateTokenResponse response = + new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope, tuple.v2()); + listener.onResponse(response); + }, e -> { + // clear the request password + request.getPassword().close(); + listener.onFailure(e); + }), Collections.emptyMap()); }, e -> { // clear the request password request.getPassword().close(); @@ -71,4 +69,19 @@ public final class TransportCreateTokenAction extends HandledTransportAction listener) { - tokenService.invalidateToken(request.getTokenString(), ActionListener.wrap( - created -> listener.onResponse(new InvalidateTokenResponse(created)), - listener::onFailure)); + protected void doExecute(InvalidateTokenRequest request, ActionListener listener) { + final ActionListener invalidateListener = + ActionListener.wrap(created -> listener.onResponse(new InvalidateTokenResponse(created)), listener::onFailure); + if (request.getTokenType() == InvalidateTokenRequest.Type.ACCESS_TOKEN) { + tokenService.invalidateAccessToken(request.getTokenString(), invalidateListener); + } else { + assert request.getTokenType() == InvalidateTokenRequest.Type.REFRESH_TOKEN; + tokenService.invalidateRefreshToken(request.getTokenString(), invalidateListener); + } } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java b/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java new file mode 100644 index 00000000000..f03810bdb06 --- /dev/null +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/action/token/TransportRefreshTokenAction.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.xpack.security.action.token; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.common.inject.Inject; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.security.authc.TokenService; + +import static org.elasticsearch.xpack.security.action.token.TransportCreateTokenAction.getResponseScopeValue; + +public class TransportRefreshTokenAction extends HandledTransportAction { + + private final TokenService tokenService; + + @Inject + public TransportRefreshTokenAction(Settings settings, ThreadPool threadPool, TransportService transportService, + ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, + TokenService tokenService) { + super(settings, RefreshTokenAction.NAME, threadPool, transportService, actionFilters, indexNameExpressionResolver, + CreateTokenRequest::new); + this.tokenService = tokenService; + } + + @Override + protected void doExecute(CreateTokenRequest request, ActionListener listener) { + tokenService.refreshToken(request.getRefreshToken(), ActionListener.wrap(tuple -> { + final String tokenStr = tokenService.getUserTokenString(tuple.v1()); + final String scope = getResponseScopeValue(request.getScope()); + + final CreateTokenResponse response = + new CreateTokenResponse(tokenStr, tokenService.getExpirationDelay(), scope, tuple.v2()); + listener.onResponse(response); + }, listener::onFailure)); + } +} diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/ExpiredTokenRemover.java b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/ExpiredTokenRemover.java index e36f32eaf2a..7cce0e14742 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/ExpiredTokenRemover.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/ExpiredTokenRemover.java @@ -25,6 +25,7 @@ import org.elasticsearch.threadpool.ThreadPool.Names; import org.elasticsearch.xpack.security.SecurityLifecycleService; import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException; @@ -50,25 +51,23 @@ final class ExpiredTokenRemover extends AbstractRunnable { @Override public void doRun() { SearchRequest searchRequest = new SearchRequest(SecurityLifecycleService.SECURITY_INDEX_NAME); - DeleteByQueryRequest dbq = new DeleteByQueryRequest(searchRequest); + DeleteByQueryRequest expiredDbq = new DeleteByQueryRequest(searchRequest); if (timeout != TimeValue.MINUS_ONE) { - dbq.setTimeout(timeout); + expiredDbq.setTimeout(timeout); searchRequest.source().timeout(timeout); } + final Instant now = Instant.now(); searchRequest.source() .query(QueryBuilders.boolQuery() - .filter(QueryBuilders.termQuery("doc_type", TokenService.DOC_TYPE)) - .filter(QueryBuilders.rangeQuery("expiration_time").lte(Instant.now().toEpochMilli()))); - executeAsyncWithOrigin(client, SECURITY_ORIGIN, DeleteByQueryAction.INSTANCE, dbq, + .filter(QueryBuilders.termsQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE, "token")) + .filter(QueryBuilders.boolQuery() + .should(QueryBuilders.rangeQuery("expiration_time").lte(now.toEpochMilli())) + .should(QueryBuilders.rangeQuery("creation_time").lte(now.minus(24L, ChronoUnit.HOURS).toEpochMilli())))); + executeAsyncWithOrigin(client, SECURITY_ORIGIN, DeleteByQueryAction.INSTANCE, expiredDbq, ActionListener.wrap(r -> { debugDbqResponse(r); markComplete(); - }, e -> { - if (isShardNotAvailableException(e) == false) { - logger.error("failed to delete expired tokens", e); - } - markComplete(); - })); + }, this::onFailure)); } void submit(ThreadPool threadPool) { @@ -98,7 +97,11 @@ final class ExpiredTokenRemover extends AbstractRunnable { @Override public void onFailure(Exception e) { - logger.error("failed to delete expired tokens", e); + if (isShardNotAvailableException(e)) { + logger.debug("failed to delete expired tokens", e); + } else { + logger.error("failed to delete expired tokens", e); + } markComplete(); } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java index d150c1f6089..782b90b1a1d 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/TokenService.java @@ -12,15 +12,24 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.UnicodeUtil; import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.DocWriteRequest.OpType; -import org.elasticsearch.action.DocWriteResponse.Result; +import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.get.MultiGetItemResponse; +import org.elasticsearch.action.get.MultiGetRequest; +import org.elasticsearch.action.get.MultiGetResponse; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.index.IndexResponse; -import org.elasticsearch.action.support.TransportActions; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; import org.elasticsearch.action.support.WriteRequest.RefreshPolicy; import org.elasticsearch.action.support.master.AcknowledgedRequest; +import org.elasticsearch.action.update.UpdateRequest; +import org.elasticsearch.action.update.UpdateResponse; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.AckedClusterStateUpdateTask; import org.elasticsearch.cluster.ClusterState; @@ -28,9 +37,12 @@ import org.elasticsearch.cluster.ack.AckedRequest; import org.elasticsearch.cluster.ack.ClusterStateUpdateResponse; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.cache.Cache; import org.elasticsearch.common.cache.CacheBuilder; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.component.AbstractComponent; +import org.elasticsearch.common.hash.MessageDigests; import org.elasticsearch.common.io.stream.InputStreamStreamInput; import org.elasticsearch.common.io.stream.OutputStreamStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; @@ -44,7 +56,11 @@ import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.common.util.iterable.Iterables; +import org.elasticsearch.common.xcontent.XContentBuilder; +import org.elasticsearch.common.xcontent.XContentFactory; +import org.elasticsearch.index.engine.DocumentMissingException; import org.elasticsearch.index.engine.VersionConflictEngineException; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.XPackSettings; import org.elasticsearch.xpack.XpackField; @@ -74,6 +90,7 @@ import java.security.spec.InvalidKeySpecException; import java.time.Clock; import java.time.Instant; import java.time.ZoneOffset; +import java.time.temporal.ChronoUnit; import java.util.ArrayList; import java.util.Arrays; import java.util.Base64; @@ -82,9 +99,12 @@ import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; +import static org.elasticsearch.action.support.TransportActions.isShardNotAvailableException; import static org.elasticsearch.gateway.GatewayService.STATE_NOT_RECOVERED_BLOCK; import static org.elasticsearch.xpack.ClientHelper.SECURITY_ORIGIN; import static org.elasticsearch.xpack.ClientHelper.executeAsyncWithOrigin; @@ -123,9 +143,9 @@ public final class TokenService extends AbstractComponent { public static final Setting DELETE_TIMEOUT = Setting.timeSetting("xpack.security.authc.token.delete.timeout", TimeValue.MINUS_ONE, Property.NodeScope); - static final String DOC_TYPE = "invalidated-token"; + static final String INVALIDATED_TOKEN_DOC_TYPE = "invalidated-token"; static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1; - static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); + private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue(); private final SecureRandom secureRandom = new SecureRandom(); private final ClusterService clusterService; @@ -136,11 +156,9 @@ public final class TokenService extends AbstractComponent { private final SecurityLifecycleService lifecycleService; private final ExpiredTokenRemover expiredTokenRemover; private final boolean enabled; - private final byte[] currentVersionBytes; private volatile TokenKeys keyCache; private volatile long lastExpirationRunMs; private final AtomicLong createdTimeStamps = new AtomicLong(-1); - private static final Version TOKEN_SERVICE_VERSION = Version.CURRENT; /** * Creates a new token service @@ -163,9 +181,8 @@ public final class TokenService extends AbstractComponent { this.deleteInterval = DELETE_INTERVAL.get(settings); this.enabled = XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.get(settings); this.expiredTokenRemover = new ExpiredTokenRemover(settings, client); - this.currentVersionBytes = ByteBuffer.allocate(4).putInt(TOKEN_SERVICE_VERSION.id).array(); ensureEncryptionCiphersSupported(); - KeyAndCache keyAndCache = new KeyAndCache(new KeyAndTimestamp(tokenPassphrase.clone(), createdTimeStamps.incrementAndGet()), + KeyAndCache keyAndCache = new KeyAndCache(new KeyAndTimestamp(tokenPassphrase, createdTimeStamps.incrementAndGet()), new BytesKey(saltArr)); keyCache = new TokenKeys(Collections.singletonMap(keyAndCache.getKeyHash(), keyAndCache), keyAndCache.getKeyHash()); this.clusterService = clusterService; @@ -175,17 +192,62 @@ public final class TokenService extends AbstractComponent { /** - * Create a token based on the provided authentication + * Create a token based on the provided authentication and metadata. + * The created token will be stored in the security index. */ - public UserToken createUserToken(Authentication authentication) - throws IOException, GeneralSecurityException { + public void createUserToken(Authentication authentication, Authentication originatingClientAuth, + ActionListener> listener, Map metadata) throws IOException { ensureEnabled(); - final Instant expiration = getExpirationTime(); - return new UserToken(authentication, expiration); + if (authentication == null) { + listener.onFailure(new IllegalArgumentException("authentication must be provided")); + } else { + final Instant created = clock.instant(); + final Instant expiration = getExpirationTime(created); + final Version version = clusterService.state().nodes().getMinNodeVersion(); + final Authentication matchingVersionAuth = version.equals(authentication.getVersion()) ? authentication : + new Authentication(authentication.getUser(), authentication.getAuthenticatedBy(), authentication.getLookedUpBy(), + version); + final UserToken userToken = new UserToken(version, matchingVersionAuth, expiration, metadata); + final String refreshToken = UUIDs.randomBase64UUID(); + + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + builder.startObject(); + builder.field("doc_type", "token"); + builder.field("creation_time", created.toEpochMilli()); + builder.startObject("refresh_token") + .field("token", refreshToken) + .field("invalidated", false) + .field("refreshed", false) + .startObject("client") + .field("type", "unassociated_client") + .field("user", originatingClientAuth.getUser().principal()) + .field("realm", originatingClientAuth.getAuthenticatedBy().getName()) + .endObject() + .endObject(); + builder.startObject("access_token") + .field("invalidated", false) + .field("user_token", userToken) + .endObject(); + builder.endObject(); + IndexRequest request = + client.prepareIndex(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, getTokenDocumentId(userToken)) + .setOpType(OpType.CREATE) + .setSource(builder) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + .request(); + lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client, SECURITY_ORIGIN, IndexAction.INSTANCE, request, + ActionListener.wrap(indexResponse -> listener.onResponse(new Tuple<>(userToken, refreshToken)), + listener::onFailure)) + ); + } + } } /** - * Looks in the context to see if the request provided a header with a user token + * Looks in the context to see if the request provided a header with a user token and if so the + * token is validated, which includes authenticated decryption and verification that the token + * has not been revoked or is expired. */ void getAndValidateToken(ThreadContext ctx, ActionListener listener) { if (enabled) { @@ -218,6 +280,13 @@ public final class TokenService extends AbstractComponent { } } + /** + * Asynchronously decodes the string representation of a {@link UserToken}. The process for + * this is asynchronous as we may need to compute a key, which can be computationally expensive + * so this should not block the current thread, which is typically a network thread. A second + * reason for being asynchronous is that we can restrain the amount of resources consumed by + * the key computation to a single thread. + */ void decodeToken(String token, ActionListener listener) throws IOException { // We intentionally do not use try-with resources since we need to keep the stream open if we need to compute a key! byte[] bytes = token.getBytes(StandardCharsets.UTF_8); @@ -228,36 +297,49 @@ public final class TokenService extends AbstractComponent { } else { // the token exists and the value is at least as long as we'd expect final Version version = Version.readVersion(in); + in.setVersion(version); final BytesKey decodedSalt = new BytesKey(in.readByteArray()); final BytesKey passphraseHash = new BytesKey(in.readByteArray()); KeyAndCache keyAndCache = keyCache.get(passphraseHash); if (keyAndCache != null) { - final SecretKey decodeKey = keyAndCache.getKey(decodedSalt); - final byte[] iv = in.readByteArray(); - if (decodeKey != null) { + getKeyAsync(decodedSalt, keyAndCache, ActionListener.wrap(decodeKey -> { try { + final byte[] iv = in.readByteArray(); decryptToken(in, getDecryptionCipher(iv, decodeKey, version, decodedSalt), version, listener); } catch (GeneralSecurityException e) { // could happen with a token that is not ours logger.warn("invalid token", e); listener.onResponse(null); + } finally { + in.close(); } - } else { - /* As a measure of protected against DOS, we can pass requests requiring a key - * computation off to a single thread executor. For normal usage, the initial - * request(s) that require a key computation will be delayed and there will be - * some additional latency. - */ - client.threadPool().executor(THREAD_POOL_NAME) - .submit(new KeyComputingRunnable(in, iv, version, decodedSalt, listener, keyAndCache)); - } + }, e -> { + IOUtils.closeWhileHandlingException(in); + listener.onFailure(e); + })); } else { + IOUtils.closeWhileHandlingException(in); logger.debug("invalid key {} key: {}", passphraseHash, keyCache.cache.keySet()); listener.onResponse(null); } } } + private void getKeyAsync(BytesKey decodedSalt, KeyAndCache keyAndCache, ActionListener listener) { + final SecretKey decodeKey = keyAndCache.getKey(decodedSalt); + if (decodeKey != null) { + listener.onResponse(decodeKey); + } else { + /* As a measure of protected against DOS, we can pass requests requiring a key + * computation off to a single thread executor. For normal usage, the initial + * request(s) that require a key computation will be delayed and there will be + * some additional latency. + */ + client.threadPool().executor(THREAD_POOL_NAME) + .submit(new KeyComputingRunnable(decodedSalt, listener, keyAndCache)); + } + } + private static void decryptToken(StreamInput in, Cipher cipher, Version version, ActionListener listener) throws IOException { try (CipherInputStream cis = new CipherInputStream(in, cipher); StreamInput decryptedInput = new InputStreamStreamInput(cis)) { @@ -267,9 +349,14 @@ public final class TokenService extends AbstractComponent { } /** - * This method records an entry to indicate that a token with a given id has been expired. + * This method performs the steps necessary to invalidate a token so that it may no longer be + * used. The process of invalidation involves a step that is needed for backwards compatibility + * with versions prior to 6.2.0; this step records an entry to indicate that a token with a + * given id has been expired. The second step is to record the invalidation for tokens that + * have been created on versions on or after 6.2; this step involves performing an update to + * the token document and setting the invalidated field to true */ - public void invalidateToken(String tokenString, ActionListener listener) { + public void invalidateAccessToken(String tokenString, ActionListener listener) { ensureEnabled(); if (Strings.isNullOrEmpty(tokenString)) { listener.onFailure(new IllegalArgumentException("token must be provided")); @@ -279,34 +366,9 @@ public final class TokenService extends AbstractComponent { decodeToken(tokenString, ActionListener.wrap(userToken -> { if (userToken == null) { listener.onFailure(malformedTokenException()); - } else if (userToken.getExpirationTime().isBefore(clock.instant())) { - // no need to invalidate - it's already expired - listener.onResponse(false); } else { - final String id = getDocumentId(userToken); - lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> { - executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, - client.prepareIndex(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, id) - .setOpType(OpType.CREATE) - .setSource("doc_type", DOC_TYPE, "expiration_time", getExpirationTime().toEpochMilli()) - .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL).request(), - new ActionListener() { - @Override - public void onResponse(IndexResponse indexResponse) { - listener.onResponse(indexResponse.getResult() == Result.CREATED); - } - - @Override - public void onFailure(Exception e) { - if (e instanceof VersionConflictEngineException) { - // doc already exists - listener.onResponse(false); - } else { - listener.onFailure(e); - } - } - }, client::index); - }); + final long expirationEpochMilli = getExpirationTime().toEpochMilli(); + indexBwcInvalidation(userToken, listener, new AtomicInteger(0), expirationEpochMilli); } }, listener::onFailure)); } catch (IOException e) { @@ -316,8 +378,344 @@ public final class TokenService extends AbstractComponent { } } - private static String getDocumentId(UserToken userToken) { - return DOC_TYPE + "_" + userToken.getId(); + public void invalidateRefreshToken(String refreshToken, ActionListener listener) { + ensureEnabled(); + if (Strings.isNullOrEmpty(refreshToken)) { + listener.onFailure(new IllegalArgumentException("refresh token must be provided")); + } else { + maybeStartTokenRemover(); + findTokenFromRefreshToken(refreshToken, + ActionListener.wrap(tuple -> { + final String docId = tuple.v1().getHits().getAt(0).getId(); + final long docVersion = tuple.v1().getHits().getAt(0).getVersion(); + indexInvalidation(docId, Version.CURRENT, listener, tuple.v2(), "refresh_token", docVersion); + }, listener::onFailure), new AtomicInteger(0)); + } + } + + /** + * Performs the actual bwc invalidation of a token and then kicks off the new invalidation method + * @param userToken the token to invalidate + * @param listener the listener to notify upon completion + * @param attemptCount the number of attempts to invalidate that have already been tried + * @param expirationEpochMilli the expiration time as milliseconds since the epoch + */ + private void indexBwcInvalidation(UserToken userToken, ActionListener listener, AtomicInteger attemptCount, + long expirationEpochMilli) { + if (attemptCount.get() > 5) { + listener.onFailure(invalidGrantException("failed to invalidate token")); + } else { + final String invalidatedTokenId = getInvalidatedTokenDocumentId(userToken); + IndexRequest indexRequest = client.prepareIndex(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, invalidatedTokenId) + .setOpType(OpType.CREATE) + .setSource("doc_type", INVALIDATED_TOKEN_DOC_TYPE, "expiration_time", expirationEpochMilli) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + .request(); + final String tokenDocId = getTokenDocumentId(userToken); + final Version version = userToken.getVersion(); + lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, indexRequest, + ActionListener.wrap(indexResponse -> { + ActionListener wrappedListener = + ActionListener.wrap(ignore -> listener.onResponse(true), listener::onFailure); + indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L); + }, e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof VersionConflictEngineException) { + // expected since something else could have invalidated + ActionListener wrappedListener = + ActionListener.wrap(ignore -> listener.onResponse(false), listener::onFailure); + indexInvalidation(tokenDocId, version, wrappedListener, attemptCount, "access_token", 1L); + } else if (isShardNotAvailableException(e)) { + attemptCount.incrementAndGet(); + indexBwcInvalidation(userToken, listener, attemptCount, expirationEpochMilli); + } else { + listener.onFailure(e); + } + }), client::index)); + } + } + + /** + * Performs the actual invalidation of a token + * @param tokenDocId the id of the token doc to invalidate + * @param listener the listener to notify upon completion + * @param attemptCount the number of attempts to invalidate that have already been tried + * @param srcPrefix the prefix to use when constructing the doc to update + * @param documentVersion the expected version of the document we will update + */ + private void indexInvalidation(String tokenDocId, Version version, ActionListener listener, AtomicInteger attemptCount, + String srcPrefix, long documentVersion) { + if (attemptCount.get() > 5) { + listener.onFailure(invalidGrantException("failed to invalidate token")); + } else { + UpdateRequest request = client.prepareUpdate(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, tokenDocId) + .setDoc(srcPrefix, Collections.singletonMap("invalidated", true)) + .setVersion(documentVersion) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + .request(); + lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, + ActionListener.wrap(updateResponse -> { + if (updateResponse.getGetResult() != null + && updateResponse.getGetResult().sourceAsMap().containsKey(srcPrefix) + && ((Map) updateResponse.getGetResult().sourceAsMap().get(srcPrefix)) + .containsKey("invalidated")) { + final boolean prevInvalidated = (boolean) + ((Map) updateResponse.getGetResult().sourceAsMap().get(srcPrefix)) + .get("invalidated"); + listener.onResponse(prevInvalidated == false); + } else { + listener.onResponse(true); + } + }, e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof DocumentMissingException) { + if (version.onOrAfter(Version.V_6_2_0)) { + // the document should always be there! + listener.onFailure(e); + } else { + listener.onResponse(false); + } + } else if (cause instanceof VersionConflictEngineException + || isShardNotAvailableException(cause)) { + attemptCount.incrementAndGet(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, + client.prepareGet(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(), + ActionListener.wrap(getResult -> { + if (getResult.isExists()) { + Map source = getResult.getSource(); + Map accessTokenSource = + (Map) source.get("access_token"); + if (accessTokenSource == null) { + listener.onFailure(new IllegalArgumentException("token document is " + + "missing access_token field")); + } else { + Boolean invalidated = (Boolean) accessTokenSource.get("invalidated"); + if (invalidated == null) { + listener.onFailure(new IllegalStateException( + "token document missing invalidated value")); + } else if (invalidated) { + listener.onResponse(false); + } else { + indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix, + getResult.getVersion()); + } + } + } else if (version.onOrAfter(Version.V_6_2_0)) { + logger.warn("could not find token document [{}] but there should " + + "be one as token has version [{}]", tokenDocId, version); + listener.onFailure(invalidGrantException("could not invalidate the token")); + } else { + listener.onResponse(false); + } + }, + e1 -> { + if (isShardNotAvailableException(e1)) { + // don't increment count; call again + indexInvalidation(tokenDocId, version, listener, attemptCount, srcPrefix, + documentVersion); + } else { + listener.onFailure(e1); + } + }), client::get); + } else { + listener.onFailure(e); + } + }), client::update)); + } + } + + /** + * Uses the refresh token to refresh its associated token and returns the new token with an + * updated expiration date to the listener + */ + public void refreshToken(String refreshToken, ActionListener> listener) { + ensureEnabled(); + findTokenFromRefreshToken(refreshToken, + ActionListener.wrap(tuple -> { + final Authentication userAuth = Authentication.readFromContext(client.threadPool().getThreadContext()); + final String tokenDocId = tuple.v1().getHits().getHits()[0].getId(); + innerRefresh(tokenDocId, userAuth, listener, tuple.v2()); + }, listener::onFailure), + new AtomicInteger(0)); + } + + private void findTokenFromRefreshToken(String refreshToken, ActionListener> listener, + AtomicInteger attemptCount) { + if (attemptCount.get() > 5) { + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else { + SearchRequest request = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME) + .setQuery(QueryBuilders.boolQuery() + .filter(QueryBuilders.termQuery("doc_type", "token")) + .filter(QueryBuilders.termQuery("refresh_token.token", refreshToken))) + .setVersion(true) + .request(); + + lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, request, + ActionListener.wrap(searchResponse -> { + if (searchResponse.isTimedOut()) { + attemptCount.incrementAndGet(); + findTokenFromRefreshToken(refreshToken, listener, attemptCount); + } else if (searchResponse.getHits().getHits().length < 1) { + logger.info("could not find token document with refresh_token [{}]", refreshToken); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else if (searchResponse.getHits().getHits().length > 1) { + listener.onFailure(new IllegalStateException("multiple tokens share the same refresh token")); + } else { + listener.onResponse(new Tuple<>(searchResponse, attemptCount)); + } + }, e -> { + if (isShardNotAvailableException(e)) { + logger.debug("failed to search for token document, retrying", e); + attemptCount.incrementAndGet(); + findTokenFromRefreshToken(refreshToken, listener, attemptCount); + } else { + listener.onFailure(e); + } + }), + client::search)); + } + } + + /** + * Performs the actual refresh of the token with retries in case of certain exceptions that + * may be recoverable. The refresh involves retrieval of the token document and then + * updating the token document to indicate that the document has been refreshed. + */ + private void innerRefresh(String tokenDocId, Authentication userAuth, ActionListener> listener, + AtomicInteger attemptCount) { + if (attemptCount.getAndIncrement() > 5) { + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } else { + GetRequest getRequest = client.prepareGet(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, tokenDocId).request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, getRequest, + ActionListener.wrap(response -> { + if (response.isExists()) { + final Map source = response.getSource(); + final Optional invalidSource = checkTokenDocForRefresh(source, userAuth); + + if (invalidSource.isPresent()) { + listener.onFailure(invalidSource.get()); + } else { + final Map userTokenSource = (Map) + ((Map) source.get("access_token")).get("user_token"); + final String authString = (String) userTokenSource.get("authentication"); + final Integer version = (Integer) userTokenSource.get("version"); + final Map metadata = (Map) userTokenSource.get("metadata"); + + Version authVersion = Version.fromId(version); + try (StreamInput in = StreamInput.wrap(Base64.getDecoder().decode(authString))) { + in.setVersion(authVersion); + Authentication authentication = new Authentication(in); + UpdateRequest updateRequest = + client.prepareUpdate(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, tokenDocId) + .setVersion(response.getVersion()) + .setDoc("refresh_token", Collections.singletonMap("refreshed", true)) + .setRefreshPolicy(RefreshPolicy.WAIT_UNTIL) + .request(); + executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, updateRequest, + ActionListener.wrap( + updateResponse -> createUserToken(authentication, userAuth, listener, metadata), + e -> { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof VersionConflictEngineException || + isShardNotAvailableException(e)) { + innerRefresh(tokenDocId, userAuth, + listener, attemptCount); + } else { + listener.onFailure(e); + } + }), + client::update); + } + } + } else { + logger.info("could not find token document [{}] for refresh", tokenDocId); + listener.onFailure(invalidGrantException("could not refresh the requested token")); + } + }, e -> { + if (isShardNotAvailableException(e)) { + innerRefresh(tokenDocId, userAuth, listener, attemptCount); + } else { + listener.onFailure(e); + } + }), client::get); + } + } + + /** + * Performs checks on the retrieved source and returns an {@link Optional} with the exception + * if there is an issue + */ + private Optional checkTokenDocForRefresh(Map source, Authentication userAuth) { + final Map refreshTokenSrc = (Map) source.get("refresh_token"); + final Map accessTokenSrc = (Map) source.get("access_token"); + if (refreshTokenSrc == null || refreshTokenSrc.isEmpty()) { + return Optional.of(invalidGrantException("token document is missing the refresh_token object")); + } else if (accessTokenSrc == null || accessTokenSrc.isEmpty()) { + return Optional.of(invalidGrantException("token document is missing the access_token object")); + } else { + final Boolean refreshed = (Boolean) refreshTokenSrc.get("refreshed"); + final Boolean invalidated = (Boolean) refreshTokenSrc.get("invalidated"); + final Long creationEpochMilli = (Long) source.get("creation_time"); + final Instant creationTime = creationEpochMilli == null ? null : Instant.ofEpochMilli(creationEpochMilli); + final Map userTokenSrc = (Map) accessTokenSrc.get("user_token"); + if (refreshed == null) { + return Optional.of(invalidGrantException("token document is missing refreshed value")); + } else if (invalidated == null) { + return Optional.of(invalidGrantException("token document is missing invalidated value")); + } else if (creationEpochMilli == null) { + return Optional.of(invalidGrantException("token document is missing creation time value")); + } else if (refreshed) { + return Optional.of(invalidGrantException("token has already been refreshed")); + } else if (invalidated) { + return Optional.of(invalidGrantException("token has been invalidated")); + } else if (clock.instant().isAfter(creationTime.plus(24L, ChronoUnit.HOURS))) { + return Optional.of(invalidGrantException("refresh token is expired")); + } else if (userTokenSrc == null || userTokenSrc.isEmpty()) { + return Optional.of(invalidGrantException("token document is missing the user token info")); + } else if (userTokenSrc.get("authentication") == null) { + return Optional.of(invalidGrantException("token is missing authentication info")); + } else if (userTokenSrc.get("version") == null) { + return Optional.of(invalidGrantException("token is missing version value")); + } else if (userTokenSrc.get("metadata") == null) { + return Optional.of(invalidGrantException("token is missing metadata")); + } else { + return checkClient(refreshTokenSrc, userAuth); + } + } + } + + private Optional checkClient(Map refreshTokenSource, Authentication userAuth) { + Map clientInfo = (Map) refreshTokenSource.get("client"); + if (clientInfo == null) { + return Optional.of(invalidGrantException("token is missing client information")); + } else if (userAuth.getUser().principal().equals(clientInfo.get("user")) == false) { + return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); + } else if (userAuth.getAuthenticatedBy().getName().equals(clientInfo.get("realm")) == false) { + return Optional.of(invalidGrantException("tokens must be refreshed by the creating client")); + } else { + return Optional.empty(); + } + } + + private static String getInvalidatedTokenDocumentId(UserToken userToken) { + return getInvalidatedTokenDocumentId(userToken.getId()); + } + + private static String getInvalidatedTokenDocumentId(String id) { + return INVALIDATED_TOKEN_DOC_TYPE + "_" + id; + } + + private static String getTokenDocumentId(UserToken userToken) { + return getTokenDocumentId(userToken.getId()); + } + + private static String getTokenDocumentId(String id) { + return "token_" + id; } private void ensureEnabled() { @@ -335,16 +733,39 @@ public final class TokenService extends AbstractComponent { // index doesn't exist so the token is considered valid. listener.onResponse(userToken); } else { - lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> + lifecycleService.prepareIndexIfNeededThenExecute(listener::onFailure, () -> { + MultiGetRequest mGetRequest = client.prepareMultiGet() + .add(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, getInvalidatedTokenDocumentId(userToken)) + .add(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, getTokenDocumentId(userToken)) + .request(); executeAsyncWithOrigin(client.threadPool().getThreadContext(), SECURITY_ORIGIN, - client.prepareGet(SecurityLifecycleService.SECURITY_INDEX_NAME, TYPE, getDocumentId(userToken)).request(), - new ActionListener() { + mGetRequest, + new ActionListener() { @Override - public void onResponse(GetResponse response) { - if (response.isExists()) { - // this token is explicitly expired! + public void onResponse(MultiGetResponse response) { + MultiGetItemResponse[] itemResponse = response.getResponses(); + if (itemResponse[0].isFailed()) { + onFailure(itemResponse[0].getFailure().getFailure()); + } else if (itemResponse[0].getResponse().isExists()) { listener.onFailure(expiredTokenException()); + } else if (itemResponse[1].isFailed()) { + onFailure(itemResponse[1].getFailure().getFailure()); + } else if (itemResponse[1].getResponse().isExists()) { + Map source = itemResponse[1].getResponse().getSource(); + Map accessTokenSource = (Map) source.get("access_token"); + if (accessTokenSource == null) { + listener.onFailure(new IllegalStateException("token document is missing access_token field")); + } else { + Boolean invalidated = (Boolean) accessTokenSource.get("invalidated"); + if (invalidated == null) { + listener.onFailure(new IllegalStateException("token document is missing invalidated field")); + } else if (invalidated) { + listener.onFailure(expiredTokenException()); + } else { + listener.onResponse(userToken); + } + } } else { listener.onResponse(userToken); } @@ -354,7 +775,7 @@ public final class TokenService extends AbstractComponent { public void onFailure(Exception e) { // if the index or the shard is not there / available we assume that // the token is not valid - if (TransportActions.isShardNotAvailableException(e)) { + if (isShardNotAvailableException(e)) { logger.warn("failed to get token [{}] since index is not available", userToken.getId()); listener.onResponse(null); } else { @@ -362,7 +783,8 @@ public final class TokenService extends AbstractComponent { listener.onFailure(e); } } - }, client::get)); + }, client::multiGet); + }); } } @@ -371,7 +793,11 @@ public final class TokenService extends AbstractComponent { } private Instant getExpirationTime() { - return clock.instant().plusSeconds(expirationDelay.getSeconds()); + return getExpirationTime(clock.instant()); + } + + private Instant getExpirationTime(Instant now) { + return now.plusSeconds(expirationDelay.getSeconds()); } private void maybeStartTokenRemover() { @@ -387,7 +813,7 @@ public final class TokenService extends AbstractComponent { * Gets the token from the Authorization header if the header begins with * Bearer */ - String getFromHeader(ThreadContext threadContext) { + private String getFromHeader(ThreadContext threadContext) { String header = threadContext.getHeader("Authorization"); if (Strings.hasLength(header) && header.startsWith("Bearer ") && header.length() > "Bearer ".length()) { @@ -404,14 +830,17 @@ public final class TokenService extends AbstractComponent { try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES); OutputStream base64 = Base64.getEncoder().wrap(os); StreamOutput out = new OutputStreamStreamOutput(base64)) { + out.setVersion(userToken.getVersion()); KeyAndCache keyAndCache = keyCache.activeKeyCache; - Version.writeVersion(TOKEN_SERVICE_VERSION, out); + Version.writeVersion(userToken.getVersion(), out); out.writeByteArray(keyAndCache.getSalt().bytes); - out.writeByteArray(keyAndCache.getKeyHash().bytes); // TODO this requires a BWC layer in 5.6 + out.writeByteArray(keyAndCache.getKeyHash().bytes); final byte[] initializationVector = getNewInitializationVector(); out.writeByteArray(initializationVector); - try (CipherOutputStream encryptedOutput = new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache)); + try (CipherOutputStream encryptedOutput = + new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, userToken.getVersion())); StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) { + encryptedStreamOutput.setVersion(userToken.getVersion()); userToken.writeTo(encryptedStreamOutput); encryptedStreamOutput.close(); return new String(os.toByteArray(), StandardCharsets.UTF_8); @@ -424,7 +853,7 @@ public final class TokenService extends AbstractComponent { SecretKeyFactory.getInstance(KDF_ALGORITHM); } - private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache) throws GeneralSecurityException { + private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException { Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER); BytesKey salt = keyAndCache.getSalt(); try { @@ -432,7 +861,7 @@ public final class TokenService extends AbstractComponent { } catch (ExecutionException e) { throw new ElasticsearchSecurityException("Failed to compute secret key for active salt", e); } - cipher.updateAAD(currentVersionBytes); + cipher.updateAAD(ByteBuffer.allocate(4).putInt(version.id).array()); cipher.updateAAD(salt.bytes); return cipher; } @@ -466,7 +895,8 @@ public final class TokenService extends AbstractComponent { /** * Creates an {@link ElasticsearchSecurityException} that indicates the token was expired. It - * is up to the client to re-authenticate and obtain a new token + * is up to the client to re-authenticate and obtain a new token. The format for this response + * is defined in */ private static ElasticsearchSecurityException expiredTokenException() { ElasticsearchSecurityException e = @@ -477,7 +907,8 @@ public final class TokenService extends AbstractComponent { /** * Creates an {@link ElasticsearchSecurityException} that indicates the token was expired. It - * is up to the client to re-authenticate and obtain a new token + * is up to the client to re-authenticate and obtain a new token. The format for this response + * is defined in */ private static ElasticsearchSecurityException malformedTokenException() { ElasticsearchSecurityException e = @@ -486,6 +917,16 @@ public final class TokenService extends AbstractComponent { return e; } + /** + * Creates an {@link ElasticsearchSecurityException} that indicates the request contained an invalid grant + */ + private static ElasticsearchSecurityException invalidGrantException(String detail) { + ElasticsearchSecurityException e = + new ElasticsearchSecurityException("invalid_grant", RestStatus.BAD_REQUEST); + e.addHeader("error_description", detail); + return e; + } + boolean isExpiredTokenException(ElasticsearchSecurityException e) { final List headers = e.getHeader("WWW-Authenticate"); return headers != null && headers.stream().anyMatch(EXPIRED_TOKEN_WWW_AUTH_VALUE::equals); @@ -497,20 +938,13 @@ public final class TokenService extends AbstractComponent { private class KeyComputingRunnable extends AbstractRunnable { - private final StreamInput in; - private final Version version; private final BytesKey decodedSalt; - private final ActionListener listener; - private final byte[] iv; + private final ActionListener listener; private final KeyAndCache keyAndCache; - KeyComputingRunnable(StreamInput input, byte[] iv, Version version, BytesKey decodedSalt, ActionListener listener, - KeyAndCache keyAndCache) { - this.in = input; - this.version = version; + KeyComputingRunnable(BytesKey decodedSalt, ActionListener listener, KeyAndCache keyAndCache) { this.decodedSalt = decodedSalt; this.listener = listener; - this.iv = iv; this.keyAndCache = keyAndCache; } @@ -518,7 +952,7 @@ public final class TokenService extends AbstractComponent { protected void doRun() { try { final SecretKey computedKey = keyAndCache.getOrComputeKey(decodedSalt); - decryptToken(in, getDecryptionCipher(iv, computedKey, version, decodedSalt), version, listener); + listener.onResponse(computedKey); } catch (ExecutionException e) { if (e.getCause() != null && (e.getCause() instanceof GeneralSecurityException || e.getCause() instanceof IOException @@ -530,9 +964,6 @@ public final class TokenService extends AbstractComponent { } else { listener.onFailure(e); } - } catch (GeneralSecurityException | IOException e) { - logger.debug("unable to decode bearer token", e); - listener.onResponse(null); } } @@ -540,11 +971,6 @@ public final class TokenService extends AbstractComponent { public void onFailure(Exception e) { listener.onFailure(e); } - - @Override - public void onAfter() { - IOUtils.closeWhileHandlingException(in); - } } /** @@ -866,7 +1292,7 @@ public final class TokenService extends AbstractComponent { } @Override - public void close() throws IOException { + public void close() { keyAndTimestamp.key.close(); } @@ -875,12 +1301,7 @@ public final class TokenService extends AbstractComponent { } private static BytesKey calculateKeyHash(SecureString key) { - MessageDigest messageDigest = null; - try { - messageDigest = MessageDigest.getInstance("SHA-256"); - } catch (NoSuchAlgorithmException e) { - throw new AssertionError(e); - } + MessageDigest messageDigest = MessageDigests.sha256(); BytesRefBuilder b = new BytesRefBuilder(); try { b.copyChars(key); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java index ca97569da0a..0d9e0f0c896 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/authc/UserToken.java @@ -7,12 +7,17 @@ package org.elasticsearch.xpack.security.authc; import org.elasticsearch.Version; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.ToXContentObject; +import org.elasticsearch.common.xcontent.XContentBuilder; import java.io.IOException; import java.time.Instant; +import java.util.Collections; +import java.util.Map; import java.util.Objects; /** @@ -24,23 +29,28 @@ import java.util.Objects; * used by an adversary to gain access. For this reason, TLS must be enabled for these tokens to * be used. */ -public final class UserToken implements Writeable { +public final class UserToken implements Writeable, ToXContentObject { private final Version version; private final String id; private final Authentication authentication; private final Instant expirationTime; + private final Map metadata; /** * Create a new token with an autogenerated id */ UserToken(Authentication authentication, Instant expirationTime) { - this.version = Version.CURRENT; + this(Version.CURRENT, authentication, expirationTime, Collections.emptyMap()); + } + + UserToken(Version version, Authentication authentication, Instant expirationTime, Map metadata) { + this.version = version; this.id = UUIDs.base64UUID(); this.authentication = Objects.requireNonNull(authentication); this.expirationTime = Objects.requireNonNull(expirationTime); + this.metadata = metadata; } - /** * Creates a new token based on the values from the stream */ @@ -49,6 +59,11 @@ public final class UserToken implements Writeable { this.id = input.readString(); this.authentication = new Authentication(input); this.expirationTime = Instant.ofEpochSecond(input.readLong(), input.readInt()); + if (version.before(Version.V_6_2_0)) { + this.metadata = Collections.emptyMap(); + } else { + this.metadata = input.readMap(); + } } @Override @@ -57,6 +72,9 @@ public final class UserToken implements Writeable { authentication.writeTo(out); out.writeLong(expirationTime.getEpochSecond()); out.writeInt(expirationTime.getNano()); + if (out.getVersion().onOrAfter(Version.V_6_2_0)) { + out.writeMap(metadata); + } } /** @@ -76,7 +94,7 @@ public final class UserToken implements Writeable { /** * The ID of this token */ - String getId() { + public String getId() { return id; } @@ -86,4 +104,26 @@ public final class UserToken implements Writeable { Version getVersion() { return version; } + + /** + * The metadata associated with this token + */ + public Map getMetadata() { + return metadata; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field("id", id); + builder.field("expiration_time", expirationTime.toEpochMilli()); + builder.field("version", version.id); + builder.field("metadata", metadata); + try (BytesStreamOutput output = new BytesStreamOutput()) { + output.setVersion(version); + authentication.writeTo(output); + builder.field("authentication", output.bytes().toBytesRef().bytes); + } + return builder.endObject(); + } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/client/SecurityClient.java b/plugin/src/main/java/org/elasticsearch/xpack/security/client/SecurityClient.java index b2ead4e65db..c2fd00a328b 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/client/SecurityClient.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/client/SecurityClient.java @@ -46,6 +46,7 @@ import org.elasticsearch.xpack.security.action.token.InvalidateTokenAction; import org.elasticsearch.xpack.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.security.action.token.InvalidateTokenRequestBuilder; import org.elasticsearch.xpack.security.action.token.InvalidateTokenResponse; +import org.elasticsearch.xpack.security.action.token.RefreshTokenAction; import org.elasticsearch.xpack.security.action.user.ChangePasswordAction; import org.elasticsearch.xpack.security.action.user.ChangePasswordRequest; import org.elasticsearch.xpack.security.action.user.ChangePasswordRequestBuilder; @@ -270,7 +271,7 @@ public class SecurityClient { } public CreateTokenRequestBuilder prepareCreateToken() { - return new CreateTokenRequestBuilder(client); + return new CreateTokenRequestBuilder(client, CreateTokenAction.INSTANCE); } public void createToken(CreateTokenRequest request, ActionListener listener) { @@ -284,4 +285,14 @@ public class SecurityClient { public void invalidateToken(InvalidateTokenRequest request, ActionListener listener) { client.execute(InvalidateTokenAction.INSTANCE, request, listener); } + + public CreateTokenRequestBuilder prepareRefreshToken(String refreshToken) { + return new CreateTokenRequestBuilder(client, RefreshTokenAction.INSTANCE) + .setRefreshToken(refreshToken) + .setGrantType("refresh_token"); + } + + public void refreshToken(CreateTokenRequest request, ActionListener listener) { + client.execute(RefreshTokenAction.INSTANCE, request, listener); + } } diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenAction.java b/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenAction.java index 03e18c8dfe2..fb4d6b6be18 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenAction.java @@ -6,6 +6,8 @@ package org.elasticsearch.xpack.security.rest.action.oauth2; import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchSecurityException; +import org.elasticsearch.action.Action; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.client.node.NodeClient; @@ -24,7 +26,9 @@ import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.security.action.token.CreateTokenAction; import org.elasticsearch.xpack.security.action.token.CreateTokenRequest; +import org.elasticsearch.xpack.security.action.token.CreateTokenRequestBuilder; import org.elasticsearch.xpack.security.action.token.CreateTokenResponse; +import org.elasticsearch.xpack.security.action.token.RefreshTokenAction; import org.elasticsearch.xpack.security.rest.action.SecurityBaseRestHandler; import java.io.IOException; @@ -43,7 +47,7 @@ import static org.elasticsearch.rest.RestRequest.Method.POST; public final class RestGetTokenAction extends SecurityBaseRestHandler { static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>("token_request", - a -> new CreateTokenRequest((String) a[0], (String) a[1], (SecureString) a[2], (String) a[3])); + a -> new CreateTokenRequest((String) a[0], (String) a[1], (SecureString) a[2], (String) a[3], (String) a[4])); static { PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("grant_type")); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("username")); @@ -51,6 +55,7 @@ public final class RestGetTokenAction extends SecurityBaseRestHandler { Arrays.copyOfRange(parser.textCharacters(), parser.textOffset(), parser.textOffset() + parser.textLength())), new ParseField("password"), ValueType.STRING); PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("scope")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refresh_token")); } public RestGetTokenAction(Settings settings, RestController controller, XPackLicenseState xPackLicenseState) { @@ -67,7 +72,9 @@ public final class RestGetTokenAction extends SecurityBaseRestHandler { protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client)throws IOException { try (XContentParser parser = request.contentParser()) { final CreateTokenRequest tokenRequest = PARSER.parse(parser, null); - return channel -> client.execute(CreateTokenAction.INSTANCE, tokenRequest, + final Action action = + "refresh_token".equals(tokenRequest.getGrantType()) ? RefreshTokenAction.INSTANCE : CreateTokenAction.INSTANCE; + return channel -> client.execute(action, tokenRequest, // this doesn't use the RestBuilderListener since we need to override the // handling of failures in some cases. new CreateTokenResponseActionListener(channel, request, logger)); @@ -100,32 +107,37 @@ public final class RestGetTokenAction extends SecurityBaseRestHandler { public void onFailure(Exception e) { if (e instanceof ActionRequestValidationException) { ActionRequestValidationException validationException = (ActionRequestValidationException) e; - try (XContentBuilder builder = channel.newErrorBuilder()) { - final TokenRequestError error; - if (validationException.validationErrors().stream().anyMatch(s -> s.contains("grant_type"))) { - error = TokenRequestError.UNSUPPORTED_GRANT_TYPE; - } else { - error = TokenRequestError.INVALID_REQUEST; - } - - // defined by https://tools.ietf.org/html/rfc6749#section-5.2 - builder.startObject() - .field("error", - error.toString().toLowerCase(Locale.ROOT)) - .field("error_description", - validationException.getMessage()) - .endObject(); - channel.sendResponse( - new BytesRestResponse(RestStatus.BAD_REQUEST, builder)); - } catch (IOException ioe) { - ioe.addSuppressed(e); - sendFailure(ioe); + final TokenRequestError error; + if (validationException.validationErrors().stream().anyMatch(s -> s.contains("grant_type"))) { + error = TokenRequestError.UNSUPPORTED_GRANT_TYPE; + } else { + error = TokenRequestError.INVALID_REQUEST; } + + sendTokenErrorResponse(error, validationException.getMessage(), e); + } else if (e instanceof ElasticsearchSecurityException && "invalid_grant".equals(e.getMessage()) && + ((ElasticsearchSecurityException) e).getHeader("error_description").size() == 1) { + sendTokenErrorResponse(TokenRequestError.INVALID_GRANT, + ((ElasticsearchSecurityException) e).getHeader("error_description").get(0), e); } else { sendFailure(e); } } + void sendTokenErrorResponse(TokenRequestError error, String description, Exception e) { + try (XContentBuilder builder = channel.newErrorBuilder()) { + // defined by https://tools.ietf.org/html/rfc6749#section-5.2 + builder.startObject() + .field("error", error.toString().toLowerCase(Locale.ROOT)) + .field("error_description", description) + .endObject(); + channel.sendResponse(new BytesRestResponse(RestStatus.BAD_REQUEST, builder)); + } catch (IOException ioe) { + ioe.addSuppressed(e); + sendFailure(e); + } + } + void sendFailure(Exception e) { try { channel.sendResponse(new BytesRestResponse(channel, e)); diff --git a/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java b/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java index b4d29736cfd..bd8caecf375 100644 --- a/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java +++ b/plugin/src/main/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestInvalidateTokenAction.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.security.rest.action.oauth2; import org.elasticsearch.client.node.NodeClient; import org.elasticsearch.common.ParseField; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.xcontent.ConstructingObjectParser; import org.elasticsearch.common.xcontent.XContentBuilder; @@ -32,10 +34,11 @@ import static org.elasticsearch.rest.RestRequest.Method.DELETE; */ public final class RestInvalidateTokenAction extends SecurityBaseRestHandler { - static final ConstructingObjectParser PARSER = - new ConstructingObjectParser<>("invalidate_token", a -> ((String) a[0])); + static final ConstructingObjectParser, Void> PARSER = + new ConstructingObjectParser<>("invalidate_token", a -> new Tuple<>((String) a[0], (String) a[1])); static { - PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField("token")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("token")); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField("refresh_token")); } public RestInvalidateTokenAction(Settings settings, RestController controller, XPackLicenseState xPackLicenseState) { @@ -51,8 +54,26 @@ public final class RestInvalidateTokenAction extends SecurityBaseRestHandler { @Override protected RestChannelConsumer innerPrepareRequest(RestRequest request, NodeClient client) throws IOException { try (XContentParser parser = request.contentParser()) { - final String token = PARSER.parse(parser, null); - final InvalidateTokenRequest tokenRequest = new InvalidateTokenRequest(token); + final Tuple tuple = PARSER.parse(parser, null); + final String token = tuple.v1(); + final String refreshToken = tuple.v2(); + + final String tokenString; + final InvalidateTokenRequest.Type type; + if (Strings.hasLength(token) && Strings.hasLength(refreshToken)) { + throw new IllegalArgumentException("only one of [token, refresh_token] may be sent per request"); + } else if (Strings.hasLength(token)) { + tokenString = token; + type = InvalidateTokenRequest.Type.ACCESS_TOKEN; + } else if (Strings.hasLength(refreshToken)) { + tokenString = refreshToken; + type = InvalidateTokenRequest.Type.REFRESH_TOKEN; + } else { + tokenString = null; + type = null; + } + + final InvalidateTokenRequest tokenRequest = new InvalidateTokenRequest(tokenString, type); return channel -> client.execute(InvalidateTokenAction.INSTANCE, tokenRequest, new RestBuilderListener(channel) { @Override diff --git a/plugin/src/main/resources/security-index-template.json b/plugin/src/main/resources/security-index-template.json index d11f07750d7..8197a825079 100644 --- a/plugin/src/main/resources/security-index-template.json +++ b/plugin/src/main/resources/security-index-template.json @@ -97,19 +97,81 @@ "run_as" : { "type" : "keyword" }, - "doc_type": { + "doc_type" : { "type" : "keyword" }, - "type": { + "type" : { "type" : "keyword" }, - "expiration_time": { - "type": "date", - "format": "epoch_millis" + "expiration_time" : { + "type" : "date", + "format" : "epoch_millis" }, - "rules": { + "creation_time" : { + "type" : "date", + "format" : "epoch_millis" + }, + "rules" : { "type" : "object", "dynamic" : true + }, + "refresh_token" : { + "type" : "object", + "properties" : { + "token" : { + "type" : "keyword" + }, + "refreshed" : { + "type" : "boolean" + }, + "invalidated" : { + "type" : "boolean" + }, + "client" : { + "type" : "object", + "properties" : { + "type" : { + "type" : "keyword" + }, + "user" : { + "type" : "keyword" + }, + "realm" : { + "type" : "keyword" + } + } + } + } + }, + "access_token" : { + "type" : "object", + "properties" : { + "user_token" : { + "type" : "object", + "properties" : { + "id" : { + "type" : "keyword" + }, + "expiration_time" : { + "type" : "date", + "format" : "epoch_millis" + }, + "version" : { + "type" : "integer" + }, + "metadata" : { + "type" : "object", + "dynamic" : true + }, + "authentication" : { + "type" : "binary" + } + } + }, + "invalidated" : { + "type" : "boolean" + } + } } } } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestTests.java index d44bcf34bf2..4445c417ae7 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/action/token/CreateTokenRequestTests.java @@ -9,15 +9,18 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.test.ESTestCase; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasItem; public class CreateTokenRequestTests extends ESTestCase { - public void testRequestValidation() throws Exception { + public void testRequestValidation() { CreateTokenRequest request = new CreateTokenRequest(); ActionRequestValidationException ve = request.validate(); assertNotNull(ve); - assertEquals(3, ve.validationErrors().size()); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), containsString("[password, refresh_token]")); + assertThat(ve.validationErrors().get(0), containsString("grant_type")); request.setGrantType("password"); ve = request.validate(); @@ -44,5 +47,29 @@ public class CreateTokenRequestTests extends ESTestCase { request.setPassword(new SecureString(randomAlphaOfLengthBetween(1, 256).toCharArray())); ve = request.validate(); assertNull(ve); + + request.setRefreshToken(randomAlphaOfLengthBetween(1, 10)); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors().get(0), containsString("refresh_token is not supported")); + + request.setGrantType("refresh_token"); + ve = request.validate(); + assertNotNull(ve); + assertEquals(2, ve.validationErrors().size()); + assertThat(ve.validationErrors(), hasItem(containsString("username is not supported"))); + assertThat(ve.validationErrors(), hasItem(containsString("password is not supported"))); + + request.setUsername(null); + request.setPassword(null); + ve = request.validate(); + assertNull(ve); + + request.setRefreshToken(null); + ve = request.validate(); + assertNotNull(ve); + assertEquals(1, ve.validationErrors().size()); + assertThat(ve.validationErrors(), hasItem("refresh_token is missing")); } } diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java index d82e74b47ae..44f5a02c5bd 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/AuthenticationServiceTests.java @@ -10,16 +10,26 @@ import org.apache.lucene.util.SetOnce; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.get.GetRequest; -import org.elasticsearch.action.get.GetRequestBuilder; import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.get.MultiGetAction; +import org.elasticsearch.action.get.MultiGetItemResponse; +import org.elasticsearch.action.get.MultiGetRequest; +import org.elasticsearch.action.get.MultiGetRequestBuilder; +import org.elasticsearch.action.get.MultiGetResponse; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.update.UpdateAction; +import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.SuppressForbidden; +import org.elasticsearch.common.collect.MapBuilder; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; @@ -30,6 +40,7 @@ import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.license.XPackLicenseState; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.threadpool.FixedExecutorBuilder; @@ -141,9 +152,22 @@ public class AuthenticationServiceTests extends ESTestCase { threadContext = threadPool.getThreadContext(); when(client.threadPool()).thenReturn(threadPool); when(client.settings()).thenReturn(settings); + when(client.prepareIndex(any(String.class), any(String.class), any(String.class))) + .thenReturn(new IndexRequestBuilder(client, IndexAction.INSTANCE)); + when(client.prepareUpdate(any(String.class), any(String.class), any(String.class))) + .thenReturn(new UpdateRequestBuilder(client, UpdateAction.INSTANCE)); + doAnswer(invocationOnMock -> { + ActionListener responseActionListener = (ActionListener) invocationOnMock.getArguments()[2]; + responseActionListener.onResponse(new IndexResponse()); + return null; + }).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class)); lifecycleService = mock(SecurityLifecycleService.class); - ClusterService clusterService = new ClusterService(settings, new ClusterSettings(settings, ClusterSettings - .BUILT_IN_CLUSTER_SETTINGS), threadPool, Collections.emptyMap()); + doAnswer(invocationOnMock -> { + Runnable runnable = (Runnable) invocationOnMock.getArguments()[1]; + runnable.run(); + return null; + }).when(lifecycleService).prepareIndexIfNeededThenExecute(any(Consumer.class), any(Runnable.class)); + ClusterService clusterService = ClusterServiceUtils.createClusterService(threadPool); tokenService = new TokenService(settings, Clock.systemUTC(), client, lifecycleService, clusterService); service = new AuthenticationService(settings, realms, auditTrail, new DefaultAuthenticationFailureHandler(), threadPool, new AnonymousUser(settings), tokenService); @@ -806,7 +830,12 @@ public class AuthenticationServiceTests extends ESTestCase { User user = new User("_username", "r1"); final AtomicBoolean completed = new AtomicBoolean(false); final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null); - String token = tokenService.getUserTokenString(tokenService.createUserToken(expected)); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); + tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap()); + } + String token = tokenService.getUserTokenString(tokenFuture.get().v1()); try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { threadContext.putHeader("Authorization", "Bearer " + token); service.authenticate("_action", message, null, ActionListener.wrap(result -> { @@ -863,20 +892,40 @@ public class AuthenticationServiceTests extends ESTestCase { } public void testExpiredToken() throws Exception { + when(lifecycleService.isSecurityIndexAvailable()).thenReturn(true); + when(lifecycleService.isSecurityIndexExisting()).thenReturn(true); User user = new User("_username", "r1"); final Authentication expected = new Authentication(user, new RealmRef("realm", "custom", "node"), null); - String token = tokenService.getUserTokenString(tokenService.createUserToken(expected)); - when(lifecycleService.isSecurityIndexExisting()).thenReturn(true); - GetRequestBuilder getRequestBuilder = mock(GetRequestBuilder.class); - when(client.prepareGet(eq(SecurityLifecycleService.SECURITY_INDEX_NAME), eq("doc"), any(String.class))) - .thenReturn(getRequestBuilder); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + try (ThreadContext.StoredContext ctx = threadContext.stashContext()) { + Authentication originatingAuth = new Authentication(new User("creator"), new RealmRef("test", "test", "test"), null); + tokenService.createUserToken(expected, originatingAuth, tokenFuture, Collections.emptyMap()); + } + String token = tokenService.getUserTokenString(tokenFuture.get().v1()); + when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + MultiGetResponse response = mock(MultiGetResponse.class); + MultiGetItemResponse[] responses = new MultiGetItemResponse[2]; + when(response.getResponses()).thenReturn(responses); + + final boolean newExpired = randomBoolean(); + GetResponse oldGetResponse = mock(GetResponse.class); + when(oldGetResponse.isExists()).thenReturn(newExpired == false); + responses[0] = new MultiGetItemResponse(oldGetResponse, null); + + GetResponse getResponse = mock(GetResponse.class); + responses[1] = new MultiGetItemResponse(getResponse, null); + when(getResponse.isExists()).thenReturn(newExpired); + if (newExpired) { + Map source = MapBuilder.newMapBuilder() + .put("access_token", Collections.singletonMap("invalidated", true)) + .immutableMap(); + when(getResponse.getSource()).thenReturn(source); + } listener.onResponse(response); return Void.TYPE; - }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + }).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class)); doAnswer(invocationOnMock -> { ((Runnable) invocationOnMock.getArguments()[1]).run(); diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java index 8d4b1647ed2..d9206a400bd 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenAuthIntegTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.SecurityIntegTestCase; import org.elasticsearch.test.SecuritySettingsSource; @@ -23,14 +24,16 @@ import org.elasticsearch.test.junit.annotations.TestLogging; import org.elasticsearch.xpack.XPackSettings; import org.elasticsearch.xpack.security.SecurityLifecycleService; import org.elasticsearch.xpack.security.action.token.CreateTokenResponse; +import org.elasticsearch.xpack.security.action.token.InvalidateTokenRequest; import org.elasticsearch.xpack.security.action.token.InvalidateTokenResponse; +import org.elasticsearch.xpack.security.action.user.AuthenticateAction; +import org.elasticsearch.xpack.security.action.user.AuthenticateRequest; +import org.elasticsearch.xpack.security.action.user.AuthenticateResponse; import org.elasticsearch.xpack.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.security.client.SecurityClient; import org.junit.After; import org.junit.Before; -import java.io.IOException; -import java.io.UncheckedIOException; import java.time.Instant; import java.time.temporal.ChronoUnit; import java.util.Collections; @@ -38,8 +41,8 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; +import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertNoTimeout; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.greaterThan; public class TokenAuthIntegTests extends SecurityIntegTestCase { @@ -49,7 +52,7 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { .put(super.nodeSettings(nodeOrdinal)) // crank up the deletion interval and set timeout for delete requests .put(TokenService.DELETE_INTERVAL.getKey(), TimeValue.timeValueSeconds(1L)) - .put(TokenService.DELETE_TIMEOUT.getKey(), TimeValue.timeValueSeconds(2L)) + .put(TokenService.DELETE_TIMEOUT.getKey(), TimeValue.timeValueSeconds(5L)) .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), true) .build(); } @@ -134,12 +137,16 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { Instant created = Instant.now(); - InvalidateTokenResponse invalidateResponse = securityClient.prepareInvalidateToken(response.getTokenString()).get(); + InvalidateTokenResponse invalidateResponse = securityClient + .prepareInvalidateToken(response.getTokenString()) + .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) + .get(); assertTrue(invalidateResponse.isCreated()); AtomicReference docId = new AtomicReference<>(); assertBusy(() -> { SearchResponse searchResponse = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME) - .setSource(SearchSourceBuilder.searchSource().query(QueryBuilders.termQuery("doc_type", TokenService.DOC_TYPE))) + .setSource(SearchSourceBuilder.searchSource() + .query(QueryBuilders.termQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE))) .setSize(1) .setTerminateAfter(1) .get(); @@ -157,18 +164,21 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { AtomicBoolean deleteTriggered = new AtomicBoolean(false); assertBusy(() -> { - assertTrue(Instant.now().isAfter(created.plusSeconds(1L).plusMillis(500L))); if (deleteTriggered.compareAndSet(false, true)) { // invalidate a invalid token... doesn't matter that it is bad... we just want this action to trigger the deletion try { - securityClient.prepareInvalidateToken("fooobar").execute().actionGet(); + securityClient.prepareInvalidateToken("fooobar") + .setType(randomFrom(InvalidateTokenRequest.Type.values())) + .execute() + .actionGet(); } catch (ElasticsearchSecurityException e) { assertEquals("token malformed", e.getMessage()); } } client.admin().indices().prepareRefresh(SecurityLifecycleService.SECURITY_INDEX_NAME).get(); SearchResponse searchResponse = client.prepareSearch(SecurityLifecycleService.SECURITY_INDEX_NAME) - .setSource(SearchSourceBuilder.searchSource().query(QueryBuilders.termQuery("doc_type", TokenService.DOC_TYPE))) + .setSource(SearchSourceBuilder.searchSource() + .query(QueryBuilders.termQuery("doc_type", TokenService.INVALIDATED_TOKEN_DOC_TYPE))) .setSize(0) .setTerminateAfter(1) .get(); @@ -176,30 +186,156 @@ public class TokenAuthIntegTests extends SecurityIntegTestCase { }, 30, TimeUnit.SECONDS); } - public void testExpireMultipleTimes() throws Exception { + public void testExpireMultipleTimes() { CreateTokenResponse response = securityClient().prepareCreateToken() .setGrantType("password") .setUsername(SecuritySettingsSource.TEST_USER_NAME) .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) .get(); - InvalidateTokenResponse invalidateResponse = securityClient().prepareInvalidateToken(response.getTokenString()).get(); - - // if the token is expired then the API will return false for created so we need to handle that - final boolean correctResponse = invalidateResponse.isCreated() || isTokenExpired(response.getTokenString()); - assertTrue(correctResponse); - assertFalse(securityClient().prepareInvalidateToken(response.getTokenString()).get().isCreated()); + InvalidateTokenResponse invalidateResponse = securityClient() + .prepareInvalidateToken(response.getTokenString()) + .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) + .get(); + assertTrue(invalidateResponse.isCreated()); + assertFalse(securityClient() + .prepareInvalidateToken(response.getTokenString()) + .setType(InvalidateTokenRequest.Type.ACCESS_TOKEN) + .get() + .isCreated()); } - private static boolean isTokenExpired(String token) { - try { - TokenService tokenService = internalCluster().getInstance(TokenService.class); - PlainActionFuture tokenFuture = new PlainActionFuture<>(); - tokenService.decodeToken(token, tokenFuture); - return tokenFuture.actionGet().getExpirationTime().isBefore(Instant.now()); - } catch (IOException e) { - throw new UncheckedIOException(e); - } + public void testRefreshingToken() { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + // get cluster health with token + assertNoTimeout(client() + .filterWithHeader(Collections.singletonMap("Authorization", "Bearer " + createTokenResponse.getTokenString())) + .admin().cluster().prepareHealth().get()); + + CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); + assertNotNull(refreshResponse.getRefreshToken()); + assertNotEquals(refreshResponse.getRefreshToken(), createTokenResponse.getRefreshToken()); + assertNotEquals(refreshResponse.getTokenString(), createTokenResponse.getTokenString()); + + assertNoTimeout(client().filterWithHeader(Collections.singletonMap("Authorization", "Bearer " + refreshResponse.getTokenString())) + .admin().cluster().prepareHealth().get()); + } + + public void testRefreshingInvalidatedToken() { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + InvalidateTokenResponse invalidateResponse = securityClient + .prepareInvalidateToken(createTokenResponse.getRefreshToken()) + .setType(InvalidateTokenRequest.Type.REFRESH_TOKEN) + .get(); + assertTrue(invalidateResponse.isCreated()); + + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, + () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); + assertEquals("invalid_grant", e.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, e.status()); + assertEquals("token has been invalidated", e.getHeader("error_description").get(0)); + } + + public void testRefreshingMultipleTimes() { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); + assertNotNull(refreshResponse); + + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, + () -> securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); + assertEquals("invalid_grant", e.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, e.status()); + assertEquals("token has already been refreshed", e.getHeader("error_description").get(0)); + } + + public void testRefreshAsDifferentUser() { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_USER_NAME, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + + ElasticsearchSecurityException e = expectThrows(ElasticsearchSecurityException.class, + () -> new SecurityClient(client() + .filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING)))) + .prepareRefreshToken(createTokenResponse.getRefreshToken()).get()); + assertEquals("invalid_grant", e.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, e.status()); + assertEquals("tokens must be refreshed by the creating client", e.getHeader("error_description").get(0)); + } + + public void testCreateThenRefreshAsDifferentUser() { + Client client = client().filterWithHeader(Collections.singletonMap("Authorization", + UsernamePasswordToken.basicAuthHeaderValue(SecuritySettingsSource.TEST_SUPERUSER, + SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING))); + SecurityClient securityClient = new SecurityClient(client); + CreateTokenResponse createTokenResponse = securityClient.prepareCreateToken() + .setGrantType("password") + .setUsername(SecuritySettingsSource.TEST_USER_NAME) + .setPassword(new SecureString(SecuritySettingsSource.TEST_PASSWORD.toCharArray())) + .get(); + assertNotNull(createTokenResponse.getRefreshToken()); + + CreateTokenResponse refreshResponse = securityClient.prepareRefreshToken(createTokenResponse.getRefreshToken()).get(); + assertNotEquals(refreshResponse.getTokenString(), createTokenResponse.getTokenString()); + assertNotEquals(refreshResponse.getRefreshToken(), createTokenResponse.getRefreshToken()); + + PlainActionFuture authFuture = new PlainActionFuture<>(); + AuthenticateRequest request = new AuthenticateRequest(); + request.username(SecuritySettingsSource.TEST_SUPERUSER); + client.execute(AuthenticateAction.INSTANCE, request, authFuture); + AuthenticateResponse response = authFuture.actionGet(); + assertEquals(SecuritySettingsSource.TEST_SUPERUSER, response.user().principal()); + + authFuture = new PlainActionFuture<>(); + request = new AuthenticateRequest(); + request.username(SecuritySettingsSource.TEST_USER_NAME); + client.filterWithHeader(Collections.singletonMap("Authorization", "Bearer " + createTokenResponse.getTokenString())) + .execute(AuthenticateAction.INSTANCE, request, authFuture); + response = authFuture.actionGet(); + assertEquals(SecuritySettingsSource.TEST_USER_NAME, response.user().principal()); + + authFuture = new PlainActionFuture<>(); + request = new AuthenticateRequest(); + request.username(SecuritySettingsSource.TEST_USER_NAME); + client.filterWithHeader(Collections.singletonMap("Authorization", "Bearer " + refreshResponse.getTokenString())) + .execute(AuthenticateAction.INSTANCE, request, authFuture); + response = authFuture.actionGet(); + assertEquals(SecuritySettingsSource.TEST_USER_NAME, response.user().principal()); } @Before diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java index 7a83a985de8..4246faf22d0 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/authc/TokenServiceTests.java @@ -9,19 +9,31 @@ import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.NoShardAvailableActionException; import org.elasticsearch.action.get.GetAction; -import org.elasticsearch.action.get.GetRequest; import org.elasticsearch.action.get.GetRequestBuilder; import org.elasticsearch.action.get.GetResponse; +import org.elasticsearch.action.get.MultiGetAction; +import org.elasticsearch.action.get.MultiGetItemResponse; +import org.elasticsearch.action.get.MultiGetRequest; +import org.elasticsearch.action.get.MultiGetRequestBuilder; +import org.elasticsearch.action.get.MultiGetResponse; +import org.elasticsearch.action.index.IndexAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.action.index.IndexResponse; import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.action.update.UpdateAction; +import org.elasticsearch.action.update.UpdateRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.collect.MapBuilder; +import org.elasticsearch.common.collect.Tuple; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.index.Index; import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.node.Node; +import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.EqualsHashCodeTestUtils; import org.elasticsearch.threadpool.FixedExecutorBuilder; @@ -37,10 +49,11 @@ import org.junit.Before; import org.junit.BeforeClass; import javax.crypto.SecretKey; -import java.security.GeneralSecurityException; +import java.io.IOException; import java.time.Clock; import java.util.Base64; import java.util.Collections; +import java.util.Map; import java.util.function.Consumer; import static java.time.Clock.systemUTC; @@ -71,27 +84,48 @@ public class TokenServiceTests extends ESTestCase { client = mock(Client.class); when(client.threadPool()).thenReturn(threadPool); when(client.settings()).thenReturn(settings); - lifecycleService = mock(SecurityLifecycleService.class); + when(client.prepareMultiGet()).thenReturn(new MultiGetRequestBuilder(client, MultiGetAction.INSTANCE)); doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[2]; - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + MultiGetResponse response = mock(MultiGetResponse.class); + MultiGetItemResponse[] responses = new MultiGetItemResponse[2]; + when(response.getResponses()).thenReturn(responses); + + GetResponse oldGetResponse = mock(GetResponse.class); + when(oldGetResponse.isExists()).thenReturn(false); + responses[0] = new MultiGetItemResponse(oldGetResponse, null); + + GetResponse getResponse = mock(GetResponse.class); + responses[1] = new MultiGetItemResponse(getResponse, null); + when(getResponse.isExists()).thenReturn(false); listener.onResponse(response); return Void.TYPE; - }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + }).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class)); + when(client.prepareIndex(any(String.class), any(String.class), any(String.class))) + .thenReturn(new IndexRequestBuilder(client, IndexAction.INSTANCE)); + when(client.prepareUpdate(any(String.class), any(String.class), any(String.class))) + .thenReturn(new UpdateRequestBuilder(client, UpdateAction.INSTANCE)); doAnswer(invocationOnMock -> { - ((Runnable) invocationOnMock.getArguments()[1]).run(); + ActionListener responseActionListener = (ActionListener) invocationOnMock.getArguments()[2]; + responseActionListener.onResponse(new IndexResponse()); + return null; + }).when(client).execute(eq(IndexAction.INSTANCE), any(IndexRequest.class), any(ActionListener.class)); + + // setup lifecycle service + lifecycleService = mock(SecurityLifecycleService.class); + doAnswer(invocationOnMock -> { + Runnable runnable = (Runnable) invocationOnMock.getArguments()[1]; + runnable.run(); return null; }).when(lifecycleService).prepareIndexIfNeededThenExecute(any(Consumer.class), any(Runnable.class)); - when(client.threadPool()).thenReturn(threadPool); - this.clusterService = new ClusterService(settings, new ClusterSettings(settings, ClusterSettings - .BUILT_IN_CLUSTER_SETTINGS), threadPool, Collections.emptyMap()); + this.clusterService = ClusterServiceUtils.createClusterService(threadPool); } @BeforeClass - public static void startThreadPool() { + public static void startThreadPool() throws IOException { threadPool = new ThreadPool(settings, new FixedExecutorBuilder(settings, TokenService.THREAD_POOL_NAME, 1, 1000, "xpack.security.authc.token.thread_pool")); + new Authentication(new User("foo"), new RealmRef("realm", "type", "node"), null).writeToContext(threadPool.getThreadContext()); } @AfterClass @@ -103,7 +137,9 @@ public class TokenServiceTests extends ESTestCase { public void testAttachAndGetToken() throws Exception { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -131,7 +167,9 @@ public class TokenServiceTests extends ESTestCase { public void testRotateKey() throws Exception { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -152,7 +190,9 @@ public class TokenServiceTests extends ESTestCase { assertEquals(authentication, serialized.getAuthentication()); } - final UserToken newToken = tokenService.createUserToken(authentication); + PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap()); + final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); @@ -184,7 +224,9 @@ public class TokenServiceTests extends ESTestCase { clusterService); otherTokenService.refreshMetaData(tokenService.getTokenMetaData()); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -211,7 +253,9 @@ public class TokenServiceTests extends ESTestCase { public void testPruneKeys() throws Exception { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -238,7 +282,9 @@ public class TokenServiceTests extends ESTestCase { assertEquals(authentication, serialized.getAuthentication()); } - final UserToken newToken = tokenService.createUserToken(authentication); + PlainActionFuture> newTokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap()); + final UserToken newToken = newTokenFuture.get().v1(); assertNotNull(newToken); assertNotEquals(tokenService.getUserTokenString(newToken), tokenService.getUserTokenString(token)); @@ -267,7 +313,9 @@ public class TokenServiceTests extends ESTestCase { public void testPassphraseWorks() throws Exception { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -294,7 +342,9 @@ public class TokenServiceTests extends ESTestCase { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + UserToken token = tokenFuture.get().v1(); assertThat(tokenService.getUserTokenString(token), notNullValue()); tokenService.clearActiveKeyCache(); @@ -306,18 +356,33 @@ public class TokenServiceTests extends ESTestCase { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); - GetRequestBuilder getRequestBuilder = mock(GetRequestBuilder.class); - when(client.prepareGet(SecurityLifecycleService.SECURITY_INDEX_NAME, "doc", TokenService.DOC_TYPE + "_" + token.getId())) - .thenReturn(getRequestBuilder); doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(true); + ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + MultiGetResponse response = mock(MultiGetResponse.class); + MultiGetItemResponse[] responses = new MultiGetItemResponse[2]; + when(response.getResponses()).thenReturn(responses); + + final boolean newExpired = randomBoolean(); + GetResponse oldGetResponse = mock(GetResponse.class); + when(oldGetResponse.isExists()).thenReturn(newExpired == false); + responses[0] = new MultiGetItemResponse(oldGetResponse, null); + + GetResponse getResponse = mock(GetResponse.class); + responses[1] = new MultiGetItemResponse(getResponse, null); + when(getResponse.isExists()).thenReturn(newExpired); + if (newExpired) { + Map source = MapBuilder.newMapBuilder() + .put("access_token", Collections.singletonMap("invalidated", true)) + .immutableMap(); + when(getResponse.getSource()).thenReturn(source); + } listener.onResponse(response); return Void.TYPE; - }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + }).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class)); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); @@ -344,7 +409,9 @@ public class TokenServiceTests extends ESTestCase { ClockMock clock = ClockMock.frozen(); TokenService tokenService = new TokenService(tokenServiceEnabledSettings, clock, client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); requestContext.putHeader("Authorization", "Bearer " + tokenService.getUserTokenString(token)); @@ -390,8 +457,8 @@ public class TokenServiceTests extends ESTestCase { TokenService tokenService = new TokenService(Settings.builder() .put(XPackSettings.TOKEN_SERVICE_ENABLED_SETTING.getKey(), false) .build(), - systemUTC(), client, lifecycleService, clusterService); - IllegalStateException e = expectThrows(IllegalStateException.class, () -> tokenService.createUserToken(null)); + Clock.systemUTC(), client, lifecycleService, clusterService); + IllegalStateException e = expectThrows(IllegalStateException.class, () -> tokenService.createUserToken(null, null, null, null)); assertEquals("tokens are not enabled", e.getMessage()); PlainActionFuture future = new PlainActionFuture<>(); @@ -400,7 +467,7 @@ public class TokenServiceTests extends ESTestCase { e = expectThrows(IllegalStateException.class, () -> { PlainActionFuture invalidateFuture = new PlainActionFuture<>(); - tokenService.invalidateToken(null, invalidateFuture); + tokenService.invalidateAccessToken(null, invalidateFuture); invalidateFuture.actionGet(); }); assertEquals("tokens are not enabled", e.getMessage()); @@ -448,7 +515,9 @@ public class TokenServiceTests extends ESTestCase { TokenService tokenService = new TokenService(tokenServiceEnabledSettings, systemUTC(), client, lifecycleService, clusterService); Authentication authentication = new Authentication(new User("joe", "admin"), new RealmRef("native_realm", "native", "node1"), null); - final UserToken token = tokenService.createUserToken(authentication); + PlainActionFuture> tokenFuture = new PlainActionFuture<>(); + tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap()); + final UserToken token = tokenFuture.get().v1(); assertNotNull(token); ThreadContext requestContext = new ThreadContext(Settings.EMPTY); @@ -458,7 +527,7 @@ public class TokenServiceTests extends ESTestCase { ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; listener.onFailure(new NoShardAvailableActionException(new ShardId(new Index("foo", "uuid"), 0), "shard oh shard")); return Void.TYPE; - }).when(client).get(any(GetRequest.class), any(ActionListener.class)); + }).when(client).multiGet(any(MultiGetRequest.class), any(ActionListener.class)); when(client.prepareGet(anyString(), anyString(), anyString())).thenReturn(new GetRequestBuilder(client, GetAction.INSTANCE)); try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) { diff --git a/plugin/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenActionTests.java b/plugin/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenActionTests.java index 9d752e93118..4c181862a47 100644 --- a/plugin/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenActionTests.java +++ b/plugin/src/test/java/org/elasticsearch/xpack/security/rest/action/oauth2/RestGetTokenActionTests.java @@ -42,7 +42,7 @@ public class RestGetTokenActionTests extends ESTestCase { }; CreateTokenResponseActionListener listener = new CreateTokenResponseActionListener(restChannel, restRequest, NoOpLogger.INSTANCE); - ActionRequestValidationException ve = new CreateTokenRequest(null, null, null, null).validate(); + ActionRequestValidationException ve = new CreateTokenRequest(null, null, null, null, null).validate(); listener.onFailure(ve); RestResponse response = responseSetOnce.get(); assertNotNull(response); @@ -66,7 +66,7 @@ public class RestGetTokenActionTests extends ESTestCase { }; CreateTokenResponseActionListener listener = new CreateTokenResponseActionListener(restChannel, restRequest, NoOpLogger.INSTANCE); CreateTokenResponse createTokenResponse = - new CreateTokenResponse(randomAlphaOfLengthBetween(1, 256), TimeValue.timeValueHours(1L), null); + new CreateTokenResponse(randomAlphaOfLengthBetween(1, 256), TimeValue.timeValueHours(1L), null, randomAlphaOfLength(4)); listener.onResponse(createTokenResponse); RestResponse response = responseSetOnce.get(); @@ -78,7 +78,8 @@ public class RestGetTokenActionTests extends ESTestCase { assertThat(map, hasEntry("type", "Bearer")); assertThat(map, hasEntry("access_token", createTokenResponse.getTokenString())); assertThat(map, hasEntry("expires_in", Math.toIntExact(createTokenResponse.getExpiresIn().seconds()))); - assertEquals(3, map.size()); + assertThat(map, hasEntry("refresh_token", createTokenResponse.getRefreshToken())); + assertEquals(4, map.size()); } public void testParser() throws Exception { @@ -96,4 +97,21 @@ public class RestGetTokenActionTests extends ESTestCase { assertTrue(SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING.equals(createTokenRequest.getPassword())); } } + + public void testParserRefreshRequest() throws Exception { + final String token = randomAlphaOfLengthBetween(4, 32); + final String request = "{" + + "\"grant_type\": \"refresh_token\"," + + "\"refresh_token\": \"" + token + "\"," + + "\"scope\": \"FULL\"" + + "}"; + try (XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, request)) { + CreateTokenRequest createTokenRequest = RestGetTokenAction.PARSER.parse(parser, null); + assertEquals("refresh_token", createTokenRequest.getGrantType()); + assertEquals(token, createTokenRequest.getRefreshToken()); + assertEquals("FULL", createTokenRequest.getScope()); + assertNull(createTokenRequest.getUsername()); + assertNull(createTokenRequest.getPassword()); + } + } } diff --git a/qa/rolling-upgrade/build.gradle b/qa/rolling-upgrade/build.gradle index 3c3ead28a0c..8d85f892724 100644 --- a/qa/rolling-upgrade/build.gradle +++ b/qa/rolling-upgrade/build.gradle @@ -14,7 +14,6 @@ dependencies { testCompile project(path: ':x-pack-elasticsearch:plugin', configuration: 'testArtifacts') } - Closure waitWithAuth = { NodeInfo node, AntBuilder ant -> File tmpFile = new File(node.cwd, 'wait.success') @@ -62,11 +61,15 @@ Closure waitWithAuth = { NodeInfo node, AntBuilder ant -> Project mainProject = project +compileTestJava.options.compilerArgs << "-Xlint:-cast,-deprecation,-rawtypes,-try,-unchecked" + /** * Subdirectories of this project are test rolling upgrades with various * configuration options based on their name. */ subprojects { + // TODO remove after backport + ext.bwc_tests_enabled = false Matcher m = project.name =~ /with(out)?-system-key/ if (false == m.matches()) { throw new InvalidUserDataException("Invalid project name [${project.name}]") @@ -268,6 +271,8 @@ subprojects { testCompile project(path: ':x-pack-elasticsearch:plugin', configuration: 'testArtifacts') } + compileTestJava.options.compilerArgs << "-Xlint:-cast,-deprecation,-rawtypes,-try,-unchecked" + // copy x-pack plugin info so it is on the classpath and security manager has the right permissions task copyXPackRestSpec(type: Copy) { dependsOn(project.configurations.restSpec, 'processTestResources') diff --git a/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java b/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java new file mode 100644 index 00000000000..5f5cf6a9b46 --- /dev/null +++ b/qa/rolling-upgrade/src/test/java/org/elasticsearch/upgrades/TokenBackwardsCompatibilityIT.java @@ -0,0 +1,307 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License; + * you may not use this file except in compliance with the Elastic License. + */ +package org.elasticsearch.upgrades; + +import org.apache.http.HttpHeaders; +import org.apache.http.HttpHost; +import org.apache.http.entity.ContentType; +import org.apache.http.entity.StringEntity; +import org.apache.http.message.BasicHeader; +import org.elasticsearch.Version; +import org.elasticsearch.client.Response; +import org.elasticsearch.client.ResponseException; +import org.elasticsearch.client.RestClient; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.test.SecuritySettingsSource; +import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.test.rest.yaml.ObjectPath; +import org.elasticsearch.xpack.security.SecurityLifecycleService; +import org.junit.Before; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.security.SecurityLifecycleService.SECURITY_TEMPLATE_NAME; +import static org.elasticsearch.xpack.security.authc.support.UsernamePasswordToken.basicAuthHeaderValue; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; + +public class TokenBackwardsCompatibilityIT extends ESRestTestCase { + + private static final String BASIC_AUTH_VALUE = + basicAuthHeaderValue("test_user", SecuritySettingsSource.TEST_PASSWORD_SECURE_STRING); + + @Override + protected boolean preserveIndicesUponCompletion() { + return true; + } + + @Override + protected boolean preserveReposUponCompletion() { + return true; + } + + @Override + protected boolean preserveTemplatesUponCompletion() { + return true; + } + + private enum CLUSTER_TYPE { + OLD, + MIXED, + UPGRADED; + + public static CLUSTER_TYPE parse(String value) { + switch (value) { + case "old_cluster": + return OLD; + case "mixed_cluster": + return MIXED; + case "upgraded_cluster": + return UPGRADED; + default: + throw new AssertionError("unknown cluster type: " + value); + } + } + } + + private final CLUSTER_TYPE clusterType = CLUSTER_TYPE.parse(System.getProperty("tests.rest.suite")); + + @Override + protected Settings restClientSettings() { + return Settings.builder() + .put(ThreadContext.PREFIX + ".Authorization", BASIC_AUTH_VALUE) + .build(); + } + + @Before + public void setupForTests() throws Exception { + final String template = SecurityLifecycleService.SECURITY_TEMPLATE_NAME; + awaitBusy(() -> { + try { + return adminClient().performRequest("HEAD", "_template/" + template).getStatusLine().getStatusCode() == 200; + } catch (IOException e) { + logger.warn("error calling template api", e); + } + return false; + }); + } + + public void testGeneratingTokenInOldCluster() throws Exception { + assumeTrue("this test should only run against the old cluster", clusterType == CLUSTER_TYPE.OLD); + final StringEntity tokenPostBody = new StringEntity("{\n" + + " \"username\": \"test_user\",\n" + + " \"password\": \"x-pack-test-password\",\n" + + " \"grant_type\": \"password\"\n" + + "}", ContentType.APPLICATION_JSON); + Response response = client().performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenPostBody); + assertOK(response); + Map responseMap = entityAsMap(response); + String token = (String) responseMap.get("access_token"); + assertNotNull(token); + assertTokenWorks(token); + + StringEntity oldClusterToken = new StringEntity("{\n" + + " \"token\": \"" + token + "\"\n" + + "}", ContentType.APPLICATION_JSON); + Response indexResponse = client().performRequest("PUT", "token_backwards_compatibility_it/doc/old_cluster_token1", + Collections.emptyMap(), oldClusterToken); + assertOK(indexResponse); + + response = client().performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenPostBody); + assertOK(response); + responseMap = entityAsMap(response); + token = (String) responseMap.get("access_token"); + assertNotNull(token); + assertTokenWorks(token); + oldClusterToken = new StringEntity("{\n" + + " \"token\": \"" + token + "\"\n" + + "}", ContentType.APPLICATION_JSON); + indexResponse = client().performRequest("PUT", "token_backwards_compatibility_it/doc/old_cluster_token2", + Collections.emptyMap(), oldClusterToken); + assertOK(indexResponse); + } + + public void testTokenWorksInMixedOrUpgradedCluster() throws Exception { + assumeTrue("this test should only run against the mixed or upgraded cluster", + clusterType == CLUSTER_TYPE.MIXED || clusterType == CLUSTER_TYPE.UPGRADED); + Response getResponse = client().performRequest("GET", "token_backwards_compatibility_it/doc/old_cluster_token1"); + assertOK(getResponse); + Map source = (Map) entityAsMap(getResponse).get("_source"); + assertTokenWorks((String) source.get("token")); + } + + public void testMixedCluster() throws Exception { + assumeTrue("this test should only run against the mixed cluster", clusterType == CLUSTER_TYPE.MIXED); + assumeTrue("the master must be on the latest version before we can write", isMasterOnLatestVersion()); + awaitIndexTemplateUpgrade(); + Response getResponse = client().performRequest("GET", "token_backwards_compatibility_it/doc/old_cluster_token2"); + assertOK(getResponse); + Map source = (Map) entityAsMap(getResponse).get("_source"); + final String token = (String) source.get("token"); + assertTokenWorks(token); + + final StringEntity body = new StringEntity("{\"token\": \"" + token + "\"}", ContentType.APPLICATION_JSON); + Response invalidationResponse = client().performRequest("DELETE", "_xpack/security/oauth2/token", Collections.emptyMap(), body); + assertOK(invalidationResponse); + assertTokenDoesNotWork(token); + + // create token and refresh on version that supports it + final StringEntity tokenPostBody = new StringEntity("{\n" + + " \"username\": \"test_user\",\n" + + " \"password\": \"x-pack-test-password\",\n" + + " \"grant_type\": \"password\"\n" + + "}", ContentType.APPLICATION_JSON); + try (RestClient client = getRestClientForCurrentVersionNodesOnly()) { + Response response = client.performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenPostBody); + assertOK(response); + Map responseMap = entityAsMap(response); + String accessToken = (String) responseMap.get("access_token"); + String refreshToken = (String) responseMap.get("refresh_token"); + assertNotNull(accessToken); + assertNotNull(refreshToken); + assertTokenWorks(accessToken); + + final StringEntity tokenRefresh = new StringEntity("{\n" + + " \"refresh_token\": \"" + refreshToken + "\",\n" + + " \"grant_type\": \"refresh_token\"\n" + + "}", ContentType.APPLICATION_JSON); + response = client.performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenRefresh); + assertOK(response); + responseMap = entityAsMap(response); + String updatedAccessToken = (String) responseMap.get("access_token"); + String updatedRefreshToken = (String) responseMap.get("refresh_token"); + assertNotNull(updatedAccessToken); + assertNotNull(updatedRefreshToken); + assertTokenWorks(updatedAccessToken); + assertTokenWorks(accessToken); + assertNotEquals(accessToken, updatedAccessToken); + assertNotEquals(refreshToken, updatedRefreshToken); + } + } + + public void testUpgradedCluster() throws Exception { + assumeTrue("this test should only run against the mixed cluster", clusterType == CLUSTER_TYPE.UPGRADED); + awaitIndexTemplateUpgrade(); + Response getResponse = client().performRequest("GET", "token_backwards_compatibility_it/doc/old_cluster_token2"); + assertOK(getResponse); + Map source = (Map) entityAsMap(getResponse).get("_source"); + final String token = (String) source.get("token"); + + // invalidate again since this may not have been invalidated in the mixed cluster + final StringEntity body = new StringEntity("{\"token\": \"" + token + "\"}", ContentType.APPLICATION_JSON); + Response invalidationResponse = client().performRequest("DELETE", "_xpack/security/oauth2/token", + Collections.singletonMap("error_trace", "true"), body); + assertOK(invalidationResponse); + assertTokenDoesNotWork(token); + + getResponse = client().performRequest("GET", "token_backwards_compatibility_it/doc/old_cluster_token1"); + assertOK(getResponse); + source = (Map) entityAsMap(getResponse).get("_source"); + final String workingToken = (String) source.get("token"); + assertTokenWorks(workingToken); + + final StringEntity tokenPostBody = new StringEntity("{\n" + + " \"username\": \"test_user\",\n" + + " \"password\": \"x-pack-test-password\",\n" + + " \"grant_type\": \"password\"\n" + + "}", ContentType.APPLICATION_JSON); + Response response = client().performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenPostBody); + assertOK(response); + Map responseMap = entityAsMap(response); + String accessToken = (String) responseMap.get("access_token"); + String refreshToken = (String) responseMap.get("refresh_token"); + assertNotNull(accessToken); + assertNotNull(refreshToken); + assertTokenWorks(accessToken); + + final StringEntity tokenRefresh = new StringEntity("{\n" + + " \"refresh_token\": \"" + refreshToken + "\",\n" + + " \"grant_type\": \"refresh_token\"\n" + + "}", ContentType.APPLICATION_JSON); + response = client().performRequest("POST", "_xpack/security/oauth2/token", Collections.emptyMap(), tokenRefresh); + assertOK(response); + responseMap = entityAsMap(response); + String updatedAccessToken = (String) responseMap.get("access_token"); + String updatedRefreshToken = (String) responseMap.get("refresh_token"); + assertNotNull(updatedAccessToken); + assertNotNull(updatedRefreshToken); + assertTokenWorks(updatedAccessToken); + assertTokenWorks(accessToken); + assertNotEquals(accessToken, updatedAccessToken); + assertNotEquals(refreshToken, updatedRefreshToken); + } + + private void assertTokenWorks(String token) throws IOException { + Response authenticateResponse = client().performRequest("GET", "_xpack/security/_authenticate", Collections.emptyMap(), + new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + token)); + assertOK(authenticateResponse); + assertEquals("test_user", entityAsMap(authenticateResponse).get("username")); + } + + private void assertTokenDoesNotWork(String token) { + ResponseException e = expectThrows(ResponseException.class, + () -> client().performRequest("GET", "_xpack/security/_authenticate", Collections.emptyMap(), + new BasicHeader(HttpHeaders.AUTHORIZATION, "Bearer " + token))); + assertEquals(401, e.getResponse().getStatusLine().getStatusCode()); + Response response = e.getResponse(); + assertEquals("Bearer realm=\"security\", error=\"invalid_token\", error_description=\"The access token expired\"", + response.getHeader("WWW-Authenticate")); + } + + private boolean isMasterOnLatestVersion() throws Exception { + Response response = client().performRequest("GET", "_cluster/state"); + assertOK(response); + final String masterNodeId = ObjectPath.createFromResponse(response).evaluate("master_node"); + response = client().performRequest("GET", "_nodes"); + assertOK(response); + ObjectPath objectPath = ObjectPath.createFromResponse(response); + return Version.CURRENT.equals(Version.fromString(objectPath.evaluate("nodes." + masterNodeId + ".version"))); + } + + private void awaitIndexTemplateUpgrade() throws Exception { + assertTrue(awaitBusy(() -> { + try { + Response response = client().performRequest("GET", "/_cluster/state/metadata"); + assertOK(response); + ObjectPath objectPath = ObjectPath.createFromResponse(response); + final String mappingsPath = "metadata.templates." + SECURITY_TEMPLATE_NAME + "" + + ".mappings"; + Map mappings = objectPath.evaluate(mappingsPath); + assertNotNull(mappings); + assertThat(mappings.size(), greaterThanOrEqualTo(1)); + String key = mappings.keySet().iterator().next(); + String templateVersion = objectPath.evaluate(mappingsPath + "." + key + "" + "._meta.security-version"); + final Version tVersion = Version.fromString(templateVersion); + return Version.CURRENT.equals(tVersion); + } catch (IOException e) { + logger.warn("caught exception checking template version", e); + return false; + } + })); + } + + private RestClient getRestClientForCurrentVersionNodesOnly() throws IOException { + Response response = client().performRequest("GET", "_nodes"); + assertOK(response); + ObjectPath objectPath = ObjectPath.createFromResponse(response); + Map nodesAsMap = objectPath.evaluate("nodes"); + List hosts = new ArrayList<>(); + for (Map.Entry entry : nodesAsMap.entrySet()) { + Map nodeDetails = (Map) entry.getValue(); + Version version = Version.fromString((String) nodeDetails.get("version")); + if (Version.CURRENT.equals(version)) { + Map httpInfo = (Map) nodeDetails.get("http"); + hosts.add(HttpHost.create((String) httpInfo.get("publish_address"))); + } + } + + return buildClient(restClientSettings(), hosts.toArray(new HttpHost[0])); + } +}