Add Jwt Client Authentication support

Closes gh-8175
This commit is contained in:
Joe Grandja 2020-11-16 19:56:45 -05:00
parent 224160f1ee
commit 9c97970e26
36 changed files with 3385 additions and 430 deletions

View File

@ -49,4 +49,5 @@
<suppress files="WebSocketMessageBrokerConfigTests\.java" checks="SpringMethodVisibility"/>
<suppress files="WebSecurityConfigurationTests\.java" checks="SpringMethodVisibility"/>
<suppress files="WithSecurityContextTestExecutionListenerTests\.java" checks="SpringMethodVisibility"/>
<suppress files="AbstractOAuth2AuthorizationGrantRequestEntityConverter\.java" checks="SpringMethodVisibility"/>
</suppressions>

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.util.Assert;
@ -27,6 +28,7 @@ import org.springframework.util.Assert;
* @author Joe Grandja
* @since 5.0
* @see AuthorizationGrantType
* @see ClientRegistration
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc6749#section-1.3">Section
* 1.3 Authorization Grant</a>
*/
@ -34,13 +36,34 @@ public abstract class AbstractOAuth2AuthorizationGrantRequest {
private final AuthorizationGrantType authorizationGrantType;
private final ClientRegistration clientRegistration;
/**
* Sub-class constructor.
* @param authorizationGrantType the authorization grant type
* @deprecated Use
* {@link #AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType, ClientRegistration)}
* instead
*/
@Deprecated
protected AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType authorizationGrantType) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
this.authorizationGrantType = authorizationGrantType;
this.clientRegistration = null;
}
/**
* Sub-class constructor.
* @param authorizationGrantType the authorization grant type
* @param clientRegistration the client registration
* @since 5.5
*/
protected AbstractOAuth2AuthorizationGrantRequest(AuthorizationGrantType authorizationGrantType,
ClientRegistration clientRegistration) {
Assert.notNull(authorizationGrantType, "authorizationGrantType cannot be null");
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
this.authorizationGrantType = authorizationGrantType;
this.clientRegistration = clientRegistration;
}
/**
@ -51,4 +74,13 @@ public abstract class AbstractOAuth2AuthorizationGrantRequest {
return this.authorizationGrantType;
}
/**
* Returns the {@link ClientRegistration client registration}.
* @return the {@link ClientRegistration}
* @since 5.5
*/
public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}
}

View File

@ -0,0 +1,173 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;
/**
* Base implementation of a {@link Converter} that converts the provided
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link RequestEntity}
* representation of an OAuth 2.0 Access Token Request for the Authorization Grant.
*
* @param <T> the type of {@link AbstractOAuth2AuthorizationGrantRequest}
* @author Joe Grandja
* @since 5.5
* @see Converter
* @see AbstractOAuth2AuthorizationGrantRequest
* @see RequestEntity
*/
abstract class AbstractOAuth2AuthorizationGrantRequestEntityConverter<T extends AbstractOAuth2AuthorizationGrantRequest>
implements Converter<T, RequestEntity<?>> {
// @formatter:off
private Converter<T, HttpHeaders> headersConverter =
(authorizationGrantRequest) -> OAuth2AuthorizationGrantRequestEntityUtils
.getTokenRequestHeaders(authorizationGrantRequest.getClientRegistration());
// @formatter:on
private Converter<T, MultiValueMap<String, String>> parametersConverter = this::createParameters;
@Override
public RequestEntity<?> convert(T authorizationGrantRequest) {
HttpHeaders headers = getHeadersConverter().convert(authorizationGrantRequest);
MultiValueMap<String, String> parameters = getParametersConverter().convert(authorizationGrantRequest);
URI uri = UriComponentsBuilder
.fromUriString(authorizationGrantRequest.getClientRegistration().getProviderDetails().getTokenUri())
.build().toUri();
return new RequestEntity<>(parameters, headers, HttpMethod.POST, uri);
}
/**
* Returns a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body.
* @param authorizationGrantRequest the authorization grant request
* @return a {@link MultiValueMap} of the parameters used in the OAuth 2.0 Access
* Token Request body
*/
abstract MultiValueMap<String, String> createParameters(T authorizationGrantRequest);
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @return the {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to {@link HttpHeaders}
*/
final Converter<T, HttpHeaders> getHeadersConverter() {
return this.headersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @param headersConverter the {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to {@link HttpHeaders}
*/
public final void setHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
Assert.notNull(headersConverter, "headersConverter cannot be null");
this.headersConverter = headersConverter;
}
/**
* Add (compose) the provided {@code headersConverter} to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link HttpHeaders}
* used in the OAuth 2.0 Access Token Request headers.
* @param headersConverter the {@link Converter} to add (compose) to the current
* {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to a {@link HttpHeaders}
*/
public final void addHeadersConverter(Converter<T, HttpHeaders> headersConverter) {
Assert.notNull(headersConverter, "headersConverter cannot be null");
Converter<T, HttpHeaders> currentHeadersConverter = this.headersConverter;
this.headersConverter = (authorizationGrantRequest) -> {
// Append headers using a Composite Converter
HttpHeaders headers = currentHeadersConverter.convert(authorizationGrantRequest);
if (headers == null) {
headers = new HttpHeaders();
}
HttpHeaders headersToAdd = headersConverter.convert(authorizationGrantRequest);
if (headersToAdd != null) {
headers.addAll(headersToAdd);
}
return headers;
};
}
/**
* Returns the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* of the parameters used in the OAuth 2.0 Access Token Request body.
* @return the {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the
* parameters
*/
final Converter<T, MultiValueMap<String, String>> getParametersConverter() {
return this.parametersConverter;
}
/**
* Sets the {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* of the parameters used in the OAuth 2.0 Access Token Request body.
* @param parametersConverter the {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the
* parameters
*/
public final void setParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
this.parametersConverter = parametersConverter;
}
/**
* Add (compose) the provided {@code parametersConverter} to the current
* {@link Converter} used for converting the
* {@link AbstractOAuth2AuthorizationGrantRequest} instance to a {@link MultiValueMap}
* of the parameters used in the OAuth 2.0 Access Token Request body.
* @param parametersConverter the {@link Converter} to add (compose) to the current
* {@link Converter} used for converting the
* {@link OAuth2AuthorizationCodeGrantRequest} to a {@link MultiValueMap} of the
* parameters
*/
public final void addParametersConverter(Converter<T, MultiValueMap<String, String>> parametersConverter) {
Assert.notNull(parametersConverter, "parametersConverter cannot be null");
Converter<T, MultiValueMap<String, String>> currentParametersConverter = this.parametersConverter;
this.parametersConverter = (authorizationGrantRequest) -> {
// Append parameters using a Composite Converter
MultiValueMap<String, String> parameters = currentParametersConverter.convert(authorizationGrantRequest);
if (parameters == null) {
parameters = new LinkedMultiValueMap<>();
}
MultiValueMap<String, String> parametersToAdd = parametersConverter.convert(authorizationGrantRequest);
if (parametersToAdd != null) {
parameters.addAll(parametersToAdd);
}
return parameters;
};
}
}

View File

@ -0,0 +1,390 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.net.URL;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;
import org.springframework.security.oauth2.core.converter.ClaimConversionService;
import org.springframework.security.oauth2.jose.JwaAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* The JOSE header is a JSON object representing the header parameters of a JSON Web
* Token, whether the JWT is a JWS or JWE, that describe the cryptographic operations
* applied to the JWT and optionally, additional properties of the JWT.
*
* @author Anoop Garlapati
* @author Joe Grandja
* @since 5.5
* @see Jwt
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-5">JWT JOSE
* Header</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-4">JWS JOSE
* Header</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7516#section-4">JWE JOSE
* Header</a>
*/
final class JoseHeader {
private final Map<String, Object> headers;
private JoseHeader(Map<String, Object> headers) {
this.headers = Collections.unmodifiableMap(new HashMap<>(headers));
}
/**
* Returns the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or
* encrypt the JWE.
* @return the {@link JwaAlgorithm}
*/
@SuppressWarnings("unchecked")
<T extends JwaAlgorithm> T getAlgorithm() {
return (T) getHeader(JoseHeaderNames.ALG);
}
/**
* Returns the JWK Set URL that refers to the resource of a set of JSON-encoded public
* keys, one of which corresponds to the key used to digitally sign the JWS or encrypt
* the JWE.
* @return the JWK Set URL
*/
URL getJwkSetUrl() {
return getHeader(JoseHeaderNames.JKU);
}
/**
* Returns the JSON Web Key which is the public key that corresponds to the key used
* to digitally sign the JWS or encrypt the JWE.
* @return the JSON Web Key
*/
Map<String, Object> getJwk() {
return getHeader(JoseHeaderNames.JWK);
}
/**
* Returns the key ID that is a hint indicating which key was used to secure the JWS
* or JWE.
* @return the key ID
*/
String getKeyId() {
return getHeader(JoseHeaderNames.KID);
}
/**
* Returns the X.509 URL that refers to the resource for the X.509 public key
* certificate or certificate chain corresponding to the key used to digitally sign
* the JWS or encrypt the JWE.
* @return the X.509 URL
*/
URL getX509Url() {
return getHeader(JoseHeaderNames.X5U);
}
/**
* Returns the X.509 certificate chain that contains the X.509 public key certificate
* or certificate chain corresponding to the key used to digitally sign the JWS or
* encrypt the JWE. The certificate or certificate chain is represented as a
* {@code List} of certificate value {@code String}s. Each {@code String} in the
* {@code List} is a Base64-encoded DER PKIX certificate value.
* @return the X.509 certificate chain
*/
List<String> getX509CertificateChain() {
return getHeader(JoseHeaderNames.X5C);
}
/**
* Returns the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1
* thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate
* corresponding to the key used to digitally sign the JWS or encrypt the JWE.
* @return the X.509 certificate SHA-1 thumbprint
*/
String getX509SHA1Thumbprint() {
return getHeader(JoseHeaderNames.X5T);
}
/**
* Returns the X.509 certificate SHA-256 thumbprint that is a base64url-encoded
* SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate
* corresponding to the key used to digitally sign the JWS or encrypt the JWE.
* @return the X.509 certificate SHA-256 thumbprint
*/
String getX509SHA256Thumbprint() {
return getHeader(JoseHeaderNames.X5T_S256);
}
/**
* Returns the type header that declares the media type of the JWS/JWE.
* @return the type header
*/
String getType() {
return getHeader(JoseHeaderNames.TYP);
}
/**
* Returns the content type header that declares the media type of the secured content
* (the payload).
* @return the content type header
*/
String getContentType() {
return getHeader(JoseHeaderNames.CTY);
}
/**
* Returns the critical headers that indicates which extensions to the JWS/JWE/JWA
* specifications are being used that MUST be understood and processed.
* @return the critical headers
*/
Set<String> getCritical() {
return getHeader(JoseHeaderNames.CRIT);
}
/**
* Returns the headers.
* @return the headers
*/
Map<String, Object> getHeaders() {
return this.headers;
}
/**
* Returns the header value.
* @param name the header name
* @param <T> the type of the header value
* @return the header value
*/
@SuppressWarnings("unchecked")
<T> T getHeader(String name) {
Assert.hasText(name, "name cannot be empty");
return (T) getHeaders().get(name);
}
/**
* Returns a new {@link Builder}, initialized with the provided {@link JwaAlgorithm}.
* @param jwaAlgorithm the {@link JwaAlgorithm}
* @return the {@link Builder}
*/
static Builder withAlgorithm(JwaAlgorithm jwaAlgorithm) {
return new Builder(jwaAlgorithm);
}
/**
* Returns a new {@link Builder}, initialized with the provided {@code headers}.
* @param headers the headers
* @return the {@link Builder}
*/
static Builder from(JoseHeader headers) {
return new Builder(headers);
}
/**
* A builder for {@link JoseHeader}.
*/
static final class Builder {
final Map<String, Object> headers = new HashMap<>();
private Builder(JwaAlgorithm jwaAlgorithm) {
algorithm(jwaAlgorithm);
}
private Builder(JoseHeader headers) {
Assert.notNull(headers, "headers cannot be null");
this.headers.putAll(headers.getHeaders());
}
/**
* Sets the {@link JwaAlgorithm JWA algorithm} used to digitally sign the JWS or
* encrypt the JWE.
* @param jwaAlgorithm the {@link JwaAlgorithm}
* @return the {@link Builder}
*/
Builder algorithm(JwaAlgorithm jwaAlgorithm) {
Assert.notNull(jwaAlgorithm, "jwaAlgorithm cannot be null");
return header(JoseHeaderNames.ALG, jwaAlgorithm);
}
/**
* Sets the JWK Set URL that refers to the resource of a set of JSON-encoded
* public keys, one of which corresponds to the key used to digitally sign the JWS
* or encrypt the JWE.
* @param jwkSetUrl the JWK Set URL
* @return the {@link Builder}
*/
Builder jwkSetUrl(String jwkSetUrl) {
return header(JoseHeaderNames.JKU, convertAsURL(JoseHeaderNames.JKU, jwkSetUrl));
}
/**
* Sets the JSON Web Key which is the public key that corresponds to the key used
* to digitally sign the JWS or encrypt the JWE.
* @param jwk the JSON Web Key
* @return the {@link Builder}
*/
Builder jwk(Map<String, Object> jwk) {
return header(JoseHeaderNames.JWK, jwk);
}
/**
* Sets the key ID that is a hint indicating which key was used to secure the JWS
* or JWE.
* @param keyId the key ID
* @return the {@link Builder}
*/
Builder keyId(String keyId) {
return header(JoseHeaderNames.KID, keyId);
}
/**
* Sets the X.509 URL that refers to the resource for the X.509 public key
* certificate or certificate chain corresponding to the key used to digitally
* sign the JWS or encrypt the JWE.
* @param x509Url the X.509 URL
* @return the {@link Builder}
*/
Builder x509Url(String x509Url) {
return header(JoseHeaderNames.X5U, convertAsURL(JoseHeaderNames.X5U, x509Url));
}
/**
* Sets the X.509 certificate chain that contains the X.509 public key certificate
* or certificate chain corresponding to the key used to digitally sign the JWS or
* encrypt the JWE. The certificate or certificate chain is represented as a
* {@code List} of certificate value {@code String}s. Each {@code String} in the
* {@code List} is a Base64-encoded DER PKIX certificate value.
* @param x509CertificateChain the X.509 certificate chain
* @return the {@link Builder}
*/
Builder x509CertificateChain(List<String> x509CertificateChain) {
return header(JoseHeaderNames.X5C, x509CertificateChain);
}
/**
* Sets the X.509 certificate SHA-1 thumbprint that is a base64url-encoded SHA-1
* thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate
* corresponding to the key used to digitally sign the JWS or encrypt the JWE.
* @param x509SHA1Thumbprint the X.509 certificate SHA-1 thumbprint
* @return the {@link Builder}
*/
Builder x509SHA1Thumbprint(String x509SHA1Thumbprint) {
return header(JoseHeaderNames.X5T, x509SHA1Thumbprint);
}
/**
* Sets the X.509 certificate SHA-256 thumbprint that is a base64url-encoded
* SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate
* corresponding to the key used to digitally sign the JWS or encrypt the JWE.
* @param x509SHA256Thumbprint the X.509 certificate SHA-256 thumbprint
* @return the {@link Builder}
*/
Builder x509SHA256Thumbprint(String x509SHA256Thumbprint) {
return header(JoseHeaderNames.X5T_S256, x509SHA256Thumbprint);
}
/**
* Sets the type header that declares the media type of the JWS/JWE.
* @param type the type header
* @return the {@link Builder}
*/
Builder type(String type) {
return header(JoseHeaderNames.TYP, type);
}
/**
* Sets the content type header that declares the media type of the secured
* content (the payload).
* @param contentType the content type header
* @return the {@link Builder}
*/
Builder contentType(String contentType) {
return header(JoseHeaderNames.CTY, contentType);
}
/**
* Sets the critical headers that indicates which extensions to the JWS/JWE/JWA
* specifications are being used that MUST be understood and processed.
* @param headerNames the critical header names
* @return the {@link Builder}
*/
Builder critical(Set<String> headerNames) {
return header(JoseHeaderNames.CRIT, headerNames);
}
/**
* Sets the header.
* @param name the header name
* @param value the header value
* @return the {@link Builder}
*/
Builder header(String name, Object value) {
Assert.hasText(name, "name cannot be empty");
Assert.notNull(value, "value cannot be null");
this.headers.put(name, value);
return this;
}
/**
* A {@code Consumer} to be provided access to the headers allowing the ability to
* add, replace, or remove.
* @param headersConsumer a {@code Consumer} of the headers
* @return the {@link Builder}
*/
Builder headers(Consumer<Map<String, Object>> headersConsumer) {
headersConsumer.accept(this.headers);
return this;
}
/**
* Builds a new {@link JoseHeader}.
* @return a {@link JoseHeader}
*/
JoseHeader build() {
Assert.notEmpty(this.headers, "headers cannot be empty");
return new JoseHeader(this.headers);
}
private static URL convertAsURL(String header, String value) {
URL convertedValue = ClaimConversionService.getSharedInstance().convert(value, URL.class);
Assert.isTrue(convertedValue != null,
() -> "Unable to convert header '" + header + "' of type '" + value.getClass() + "' to URL.");
return convertedValue;
}
}
}

View File

@ -0,0 +1,127 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* The Registered Header Parameter Names defined by the JSON Web Token (JWT), JSON Web
* Signature (JWS) and JSON Web Encryption (JWE) specifications that may be contained in
* the JOSE Header of a JWT.
*
* @author Anoop Garlapati
* @author Joe Grandja
* @since 5.5
* @see JoseHeader
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-5">JWT JOSE
* Header</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-4">JWS JOSE
* Header</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7516#section-4">JWE JOSE
* Header</a>
*/
final class JoseHeaderNames {
/**
* {@code alg} - the algorithm header identifies the cryptographic algorithm used to
* secure a JWS or JWE
*/
static final String ALG = "alg";
/**
* {@code jku} - the JWK Set URL header is a URI that refers to a resource for a set
* of JSON-encoded public keys, one of which corresponds to the key used to digitally
* sign a JWS or encrypt a JWE
*/
static final String JKU = "jku";
/**
* {@code jwk} - the JSON Web Key header is the public key that corresponds to the key
* used to digitally sign a JWS or encrypt a JWE
*/
static final String JWK = "jwk";
/**
* {@code kid} - the key ID header is a hint indicating which key was used to secure a
* JWS or JWE
*/
static final String KID = "kid";
/**
* {@code x5u} - the X.509 URL header is a URI that refers to a resource for the X.509
* public key certificate or certificate chain corresponding to the key used to
* digitally sign a JWS or encrypt a JWE
*/
static final String X5U = "x5u";
/**
* {@code x5c} - the X.509 certificate chain header contains the X.509 public key
* certificate or certificate chain corresponding to the key used to digitally sign a
* JWS or encrypt a JWE
*/
static final String X5C = "x5c";
/**
* {@code x5t} - the X.509 certificate SHA-1 thumbprint header is a base64url-encoded
* SHA-1 thumbprint (a.k.a. digest) of the DER encoding of the X.509 certificate
* corresponding to the key used to digitally sign a JWS or encrypt a JWE
*/
static final String X5T = "x5t";
/**
* {@code x5t#S256} - the X.509 certificate SHA-256 thumbprint header is a
* base64url-encoded SHA-256 thumbprint (a.k.a. digest) of the DER encoding of the
* X.509 certificate corresponding to the key used to digitally sign a JWS or encrypt
* a JWE
*/
static final String X5T_S256 = "x5t#S256";
/**
* {@code typ} - the type header is used by JWS/JWE applications to declare the media
* type of a JWS/JWE
*/
static final String TYP = "typ";
/**
* {@code cty} - the content type header is used by JWS/JWE applications to declare
* the media type of the secured content (the payload)
*/
static final String CTY = "cty";
/**
* {@code crit} - the critical header indicates that extensions to the JWS/JWE/JWA
* specifications are being used that MUST be understood and processed
*/
static final String CRIT = "crit";
private JoseHeaderNames() {
}
}

View File

@ -0,0 +1,222 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.net.URL;
import java.time.Instant;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.springframework.security.oauth2.core.converter.ClaimConversionService;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimAccessor;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.util.Assert;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* The {@link Jwt JWT} Claims Set is a JSON object representing the claims conveyed by a
* JSON Web Token.
*
* @author Anoop Garlapati
* @author Joe Grandja
* @since 5.5
* @see Jwt
* @see JwtClaimAccessor
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519#section-4">JWT Claims
* Set</a>
*/
final class JwtClaimsSet implements JwtClaimAccessor {
private final Map<String, Object> claims;
private JwtClaimsSet(Map<String, Object> claims) {
this.claims = Collections.unmodifiableMap(new HashMap<>(claims));
}
@Override
public Map<String, Object> getClaims() {
return this.claims;
}
/**
* Returns a new {@link Builder}.
* @return the {@link Builder}
*/
static Builder builder() {
return new Builder();
}
/**
* Returns a new {@link Builder}, initialized with the provided {@code claims}.
* @param claims a JWT claims set
* @return the {@link Builder}
*/
static Builder from(JwtClaimsSet claims) {
return new Builder(claims);
}
/**
* A builder for {@link JwtClaimsSet}.
*/
static final class Builder {
final Map<String, Object> claims = new HashMap<>();
private Builder() {
}
private Builder(JwtClaimsSet claims) {
Assert.notNull(claims, "claims cannot be null");
this.claims.putAll(claims.getClaims());
}
/**
* Sets the issuer {@code (iss)} claim, which identifies the principal that issued
* the JWT.
* @param issuer the issuer identifier
* @return the {@link Builder}
*/
Builder issuer(String issuer) {
return claim(JwtClaimNames.ISS, issuer);
}
/**
* Sets the subject {@code (sub)} claim, which identifies the principal that is
* the subject of the JWT.
* @param subject the subject identifier
* @return the {@link Builder}
*/
Builder subject(String subject) {
return claim(JwtClaimNames.SUB, subject);
}
/**
* Sets the audience {@code (aud)} claim, which identifies the recipient(s) that
* the JWT is intended for.
* @param audience the audience that this JWT is intended for
* @return the {@link Builder}
*/
Builder audience(List<String> audience) {
return claim(JwtClaimNames.AUD, audience);
}
/**
* Sets the expiration time {@code (exp)} claim, which identifies the time on or
* after which the JWT MUST NOT be accepted for processing.
* @param expiresAt the time on or after which the JWT MUST NOT be accepted for
* processing
* @return the {@link Builder}
*/
Builder expiresAt(Instant expiresAt) {
return claim(JwtClaimNames.EXP, expiresAt);
}
/**
* Sets the not before {@code (nbf)} claim, which identifies the time before which
* the JWT MUST NOT be accepted for processing.
* @param notBefore the time before which the JWT MUST NOT be accepted for
* processing
* @return the {@link Builder}
*/
Builder notBefore(Instant notBefore) {
return claim(JwtClaimNames.NBF, notBefore);
}
/**
* Sets the issued at {@code (iat)} claim, which identifies the time at which the
* JWT was issued.
* @param issuedAt the time at which the JWT was issued
* @return the {@link Builder}
*/
Builder issuedAt(Instant issuedAt) {
return claim(JwtClaimNames.IAT, issuedAt);
}
/**
* Sets the JWT ID {@code (jti)} claim, which provides a unique identifier for the
* JWT.
* @param jti the unique identifier for the JWT
* @return the {@link Builder}
*/
Builder id(String jti) {
return claim(JwtClaimNames.JTI, jti);
}
/**
* Sets the claim.
* @param name the claim name
* @param value the claim value
* @return the {@link Builder}
*/
Builder claim(String name, Object value) {
Assert.hasText(name, "name cannot be empty");
Assert.notNull(value, "value cannot be null");
this.claims.put(name, value);
return this;
}
/**
* A {@code Consumer} to be provided access to the claims allowing the ability to
* add, replace, or remove.
* @param claimsConsumer a {@code Consumer} of the claims
*/
Builder claims(Consumer<Map<String, Object>> claimsConsumer) {
claimsConsumer.accept(this.claims);
return this;
}
/**
* Builds a new {@link JwtClaimsSet}.
* @return a {@link JwtClaimsSet}
*/
JwtClaimsSet build() {
Assert.notEmpty(this.claims, "claims cannot be empty");
// The value of the 'iss' claim is a String or URL (StringOrURI).
// Attempt to convert to URL.
Object issuer = this.claims.get(JwtClaimNames.ISS);
if (issuer != null) {
URL convertedValue = ClaimConversionService.getSharedInstance().convert(issuer, URL.class);
if (convertedValue != null) {
this.claims.put(JwtClaimNames.ISS, convertedValue);
}
}
return new JwtClaimsSet(this.claims);
}
}
}

View File

@ -0,0 +1,62 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.springframework.security.oauth2.jwt.JwtException;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* This exception is thrown when an error occurs while attempting to encode a JSON Web
* Token (JWT).
*
* @author Joe Grandja
* @since 5.5
*/
class JwtEncodingException extends JwtException {
/**
* Constructs a {@code JwtEncodingException} using the provided parameters.
* @param message the detail message
*/
JwtEncodingException(String message) {
super(message);
}
/**
* Constructs a {@code JwtEncodingException} using the provided parameters.
* @param message the detail message
* @param cause the root cause
*/
JwtEncodingException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,359 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import java.net.URL;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObjectType;
import com.nimbusds.jose.JWSAlgorithm;
import com.nimbusds.jose.JWSHeader;
import com.nimbusds.jose.JWSSigner;
import com.nimbusds.jose.crypto.factories.DefaultJWSSignerFactory;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKMatcher;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.produce.JWSSignerFactory;
import com.nimbusds.jose.util.Base64;
import com.nimbusds.jose.util.Base64URL;
import com.nimbusds.jwt.JWTClaimsSet;
import com.nimbusds.jwt.SignedJWT;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* A JWT encoder that encodes a JSON Web Token (JWT) using the JSON Web Signature (JWS)
* Compact Serialization format. The private/secret key used for signing the JWS is
* supplied by the {@code com.nimbusds.jose.jwk.source.JWKSource} provided via the
* constructor.
*
* <p>
* <b>NOTE:</b> This implementation uses the Nimbus JOSE + JWT SDK.
*
* @author Joe Grandja
* @since 5.5
* @see com.nimbusds.jose.jwk.source.JWKSource
* @see com.nimbusds.jose.jwk.JWK
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7519">JSON Web Token
* (JWT)</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515">JSON Web Signature
* (JWS)</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7515#section-3.1">JWS
* Compact Serialization</a>
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus
* JOSE + JWT SDK</a>
*/
final class NimbusJwsEncoder {
private static final String ENCODING_ERROR_MESSAGE_TEMPLATE = "An error occurred while attempting to encode the Jwt: %s";
private static final JWSSignerFactory JWS_SIGNER_FACTORY = new DefaultJWSSignerFactory();
private final Map<JWK, JWSSigner> jwsSigners = new ConcurrentHashMap<>();
private final JWKSource<SecurityContext> jwkSource;
/**
* Constructs a {@code NimbusJwsEncoder} using the provided parameters.
* @param jwkSource the {@code com.nimbusds.jose.jwk.source.JWKSource}
*/
NimbusJwsEncoder(JWKSource<SecurityContext> jwkSource) {
Assert.notNull(jwkSource, "jwkSource cannot be null");
this.jwkSource = jwkSource;
}
Jwt encode(JoseHeader headers, JwtClaimsSet claims) throws JwtEncodingException {
Assert.notNull(headers, "headers cannot be null");
Assert.notNull(claims, "claims cannot be null");
JWK jwk = selectJwk(headers);
headers = addKeyIdentifierHeadersIfNecessary(headers, jwk);
String jws = serialize(headers, claims, jwk);
return new Jwt(jws, claims.getIssuedAt(), claims.getExpiresAt(), headers.getHeaders(), claims.getClaims());
}
private JWK selectJwk(JoseHeader headers) {
List<JWK> jwks;
try {
JWKSelector jwkSelector = new JWKSelector(createJwkMatcher(headers));
jwks = this.jwkSource.get(jwkSelector, null);
}
catch (Exception ex) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Failed to select a JWK signing key -> " + ex.getMessage()), ex);
}
if (jwks.size() > 1) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Found multiple JWK signing keys for algorithm '" + headers.getAlgorithm().getName() + "'"));
}
if (jwks.isEmpty()) {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to select a JWK signing key"));
}
return jwks.get(0);
}
private String serialize(JoseHeader headers, JwtClaimsSet claims, JWK jwk) {
JWSHeader jwsHeader = convert(headers);
JWTClaimsSet jwtClaimsSet = convert(claims);
JWSSigner jwsSigner = this.jwsSigners.computeIfAbsent(jwk, NimbusJwsEncoder::createSigner);
SignedJWT signedJwt = new SignedJWT(jwsHeader, jwtClaimsSet);
try {
signedJwt.sign(jwsSigner);
}
catch (JOSEException ex) {
throw new JwtEncodingException(
String.format(ENCODING_ERROR_MESSAGE_TEMPLATE, "Failed to sign the JWT -> " + ex.getMessage()), ex);
}
return signedJwt.serialize();
}
private static JWKMatcher createJwkMatcher(JoseHeader headers) {
JWSAlgorithm jwsAlgorithm = JWSAlgorithm.parse(headers.getAlgorithm().getName());
if (JWSAlgorithm.Family.RSA.contains(jwsAlgorithm) || JWSAlgorithm.Family.EC.contains(jwsAlgorithm)) {
// @formatter:off
return new JWKMatcher.Builder()
.keyType(KeyType.forAlgorithm(jwsAlgorithm))
.keyID(headers.getKeyId())
.keyUses(KeyUse.SIGNATURE, null)
.algorithms(jwsAlgorithm, null)
.x509CertSHA256Thumbprint(Base64URL.from(headers.getX509SHA256Thumbprint()))
.build();
// @formatter:on
}
else if (JWSAlgorithm.Family.HMAC_SHA.contains(jwsAlgorithm)) {
// @formatter:off
return new JWKMatcher.Builder()
.keyType(KeyType.forAlgorithm(jwsAlgorithm))
.keyID(headers.getKeyId())
.privateOnly(true)
.algorithms(jwsAlgorithm, null)
.build();
// @formatter:on
}
return null;
}
private static JoseHeader addKeyIdentifierHeadersIfNecessary(JoseHeader headers, JWK jwk) {
// Check if headers have already been added
if (StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(headers.getX509SHA256Thumbprint())) {
return headers;
}
// Check if headers can be added from JWK
if (!StringUtils.hasText(jwk.getKeyID()) && jwk.getX509CertSHA256Thumbprint() == null) {
return headers;
}
JoseHeader.Builder headersBuilder = JoseHeader.from(headers);
if (!StringUtils.hasText(headers.getKeyId()) && StringUtils.hasText(jwk.getKeyID())) {
headersBuilder.keyId(jwk.getKeyID());
}
if (!StringUtils.hasText(headers.getX509SHA256Thumbprint()) && jwk.getX509CertSHA256Thumbprint() != null) {
headersBuilder.x509SHA256Thumbprint(jwk.getX509CertSHA256Thumbprint().toString());
}
return headersBuilder.build();
}
private static JWSSigner createSigner(JWK jwk) {
try {
return JWS_SIGNER_FACTORY.createJWSSigner(jwk);
}
catch (JOSEException ex) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Failed to create a JWS Signer -> " + ex.getMessage()), ex);
}
}
private static JWSHeader convert(JoseHeader headers) {
JWSHeader.Builder builder = new JWSHeader.Builder(JWSAlgorithm.parse(headers.getAlgorithm().getName()));
if (headers.getJwkSetUrl() != null) {
builder.jwkURL(convertAsURI(JoseHeaderNames.JKU, headers.getJwkSetUrl()));
}
Map<String, Object> jwk = headers.getJwk();
if (!CollectionUtils.isEmpty(jwk)) {
try {
builder.jwk(JWK.parse(jwk));
}
catch (Exception ex) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Unable to convert '" + JoseHeaderNames.JWK + "' JOSE header"), ex);
}
}
String keyId = headers.getKeyId();
if (StringUtils.hasText(keyId)) {
builder.keyID(keyId);
}
if (headers.getX509Url() != null) {
builder.x509CertURL(convertAsURI(JoseHeaderNames.X5U, headers.getX509Url()));
}
List<String> x509CertificateChain = headers.getX509CertificateChain();
if (!CollectionUtils.isEmpty(x509CertificateChain)) {
List<Base64> x5cList = new ArrayList<>();
x509CertificateChain.forEach((x5c) -> x5cList.add(new Base64(x5c)));
if (!x5cList.isEmpty()) {
builder.x509CertChain(x5cList);
}
}
String x509SHA1Thumbprint = headers.getX509SHA1Thumbprint();
if (StringUtils.hasText(x509SHA1Thumbprint)) {
builder.x509CertThumbprint(new Base64URL(x509SHA1Thumbprint));
}
String x509SHA256Thumbprint = headers.getX509SHA256Thumbprint();
if (StringUtils.hasText(x509SHA256Thumbprint)) {
builder.x509CertSHA256Thumbprint(new Base64URL(x509SHA256Thumbprint));
}
String type = headers.getType();
if (StringUtils.hasText(type)) {
builder.type(new JOSEObjectType(type));
}
String contentType = headers.getContentType();
if (StringUtils.hasText(contentType)) {
builder.contentType(contentType);
}
Set<String> critical = headers.getCritical();
if (!CollectionUtils.isEmpty(critical)) {
builder.criticalParams(critical);
}
Map<String, Object> customHeaders = new HashMap<>();
headers.getHeaders().forEach((name, value) -> {
if (!JWSHeader.getRegisteredParameterNames().contains(name)) {
customHeaders.put(name, value);
}
});
if (!customHeaders.isEmpty()) {
builder.customParams(customHeaders);
}
return builder.build();
}
private static JWTClaimsSet convert(JwtClaimsSet claims) {
JWTClaimsSet.Builder builder = new JWTClaimsSet.Builder();
// NOTE: The value of the 'iss' claim is a String or URL (StringOrURI).
Object issuer = claims.getClaim(JwtClaimNames.ISS);
if (issuer != null) {
builder.issuer(issuer.toString());
}
String subject = claims.getSubject();
if (StringUtils.hasText(subject)) {
builder.subject(subject);
}
List<String> audience = claims.getAudience();
if (!CollectionUtils.isEmpty(audience)) {
builder.audience(audience);
}
Instant expiresAt = claims.getExpiresAt();
if (expiresAt != null) {
builder.expirationTime(Date.from(expiresAt));
}
Instant notBefore = claims.getNotBefore();
if (notBefore != null) {
builder.notBeforeTime(Date.from(notBefore));
}
Instant issuedAt = claims.getIssuedAt();
if (issuedAt != null) {
builder.issueTime(Date.from(issuedAt));
}
String jwtId = claims.getId();
if (StringUtils.hasText(jwtId)) {
builder.jwtID(jwtId);
}
Map<String, Object> customClaims = new HashMap<>();
claims.getClaims().forEach((name, value) -> {
if (!JWTClaimsSet.getRegisteredNames().contains(name)) {
customClaims.put(name, value);
}
});
if (!customClaims.isEmpty()) {
customClaims.forEach(builder::claim);
}
return builder.build();
}
private static URI convertAsURI(String header, URL url) {
try {
return url.toURI();
}
catch (Exception ex) {
throw new JwtEncodingException(String.format(ENCODING_ERROR_MESSAGE_TEMPLATE,
"Unable to convert '" + header + "' JOSE header to a URI"), ex);
}
}
}

View File

@ -0,0 +1,183 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.time.Duration;
import java.time.Instant;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyType;
import com.nimbusds.jose.jwk.source.ImmutableJWKSet;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import org.springframework.core.convert.converter.Converter;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.OAuth2Error;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.jws.JwsAlgorithm;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
/**
* A {@link Converter} that customizes the OAuth 2.0 Access Token Request parameters by
* adding a signed JSON Web Token (JWS) to be used for client authentication at the
* Authorization Server's Token Endpoint. The private/secret key used for signing the JWS
* is supplied by the {@code com.nimbusds.jose.jwk.JWK} resolver provided via the
* constructor.
*
* <p>
* <b>NOTE:</b> This implementation uses the Nimbus JOSE + JWT SDK.
*
* @param <T> the type of {@link AbstractOAuth2AuthorizationGrantRequest}
* @author Joe Grandja
* @since 5.5
* @see Converter
* @see com.nimbusds.jose.jwk.JWK
* @see OAuth2AuthorizationCodeGrantRequestEntityConverter#addParametersConverter(Converter)
* @see OAuth2ClientCredentialsGrantRequestEntityConverter#addParametersConverter(Converter)
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7523#section-2.2">2.2
* Using JWTs for Client Authentication</a>
* @see <a target="_blank" href="https://tools.ietf.org/html/rfc7521#section-4.2">4.2
* Using Assertions for Client Authentication</a>
* @see <a target="_blank" href="https://connect2id.com/products/nimbus-jose-jwt">Nimbus
* JOSE + JWT SDK</a>
*/
public final class NimbusJwtClientAuthenticationParametersConverter<T extends AbstractOAuth2AuthorizationGrantRequest>
implements Converter<T, MultiValueMap<String, String>> {
private static final String INVALID_KEY_ERROR_CODE = "invalid_key";
private static final String INVALID_ALGORITHM_ERROR_CODE = "invalid_algorithm";
private static final String CLIENT_ASSERTION_TYPE_VALUE = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer";
private final Function<ClientRegistration, JWK> jwkResolver;
private final Map<String, NimbusJwsEncoder> jwsEncoders = new ConcurrentHashMap<>();
/**
* Constructs a {@code NimbusJwtClientAuthenticationParametersConverter} using the
* provided parameters.
* @param jwkResolver the resolver that provides the {@code com.nimbusds.jose.jwk.JWK}
* associated to the {@link ClientRegistration client}
*/
public NimbusJwtClientAuthenticationParametersConverter(Function<ClientRegistration, JWK> jwkResolver) {
Assert.notNull(jwkResolver, "jwkResolver cannot be null");
this.jwkResolver = jwkResolver;
}
@Override
public MultiValueMap<String, String> convert(T authorizationGrantRequest) {
Assert.notNull(authorizationGrantRequest, "authorizationGrantRequest cannot be null");
ClientRegistration clientRegistration = authorizationGrantRequest.getClientRegistration();
if (!ClientAuthenticationMethod.PRIVATE_KEY_JWT.equals(clientRegistration.getClientAuthenticationMethod())
&& !ClientAuthenticationMethod.CLIENT_SECRET_JWT
.equals(clientRegistration.getClientAuthenticationMethod())) {
return null;
}
JWK jwk = this.jwkResolver.apply(clientRegistration);
if (jwk == null) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_KEY_ERROR_CODE,
"Failed to resolve JWK signing key for client registration '"
+ clientRegistration.getRegistrationId() + "'.",
null);
throw new OAuth2AuthorizationException(oauth2Error);
}
JwsAlgorithm jwsAlgorithm = resolveAlgorithm(jwk);
if (jwsAlgorithm == null) {
OAuth2Error oauth2Error = new OAuth2Error(INVALID_ALGORITHM_ERROR_CODE,
"Unable to resolve JWS (signing) algorithm from JWK associated to client registration '"
+ clientRegistration.getRegistrationId() + "'.",
null);
throw new OAuth2AuthorizationException(oauth2Error);
}
JoseHeader.Builder headersBuilder = JoseHeader.withAlgorithm(jwsAlgorithm);
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(Duration.ofSeconds(60));
// @formatter:off
JwtClaimsSet.Builder claimsBuilder = JwtClaimsSet.builder()
.issuer(clientRegistration.getClientId())
.subject(clientRegistration.getClientId())
.audience(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()))
.id(UUID.randomUUID().toString())
.issuedAt(issuedAt)
.expiresAt(expiresAt);
// @formatter:on
JoseHeader joseHeader = headersBuilder.build();
JwtClaimsSet jwtClaimsSet = claimsBuilder.build();
NimbusJwsEncoder jwsEncoder = this.jwsEncoders.computeIfAbsent(clientRegistration.getRegistrationId(),
(clientRegistrationId) -> {
JWKSource<SecurityContext> jwkSource = new ImmutableJWKSet<>(new JWKSet(jwk));
return new NimbusJwsEncoder(jwkSource);
});
Jwt jws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE, CLIENT_ASSERTION_TYPE_VALUE);
parameters.set(OAuth2ParameterNames.CLIENT_ASSERTION, jws.getTokenValue());
return parameters;
}
private static JwsAlgorithm resolveAlgorithm(JWK jwk) {
JwsAlgorithm jwsAlgorithm = null;
if (jwk.getAlgorithm() != null) {
jwsAlgorithm = SignatureAlgorithm.from(jwk.getAlgorithm().getName());
if (jwsAlgorithm == null) {
jwsAlgorithm = MacAlgorithm.from(jwk.getAlgorithm().getName());
}
}
if (jwsAlgorithm == null) {
if (KeyType.RSA.equals(jwk.getKeyType())) {
jwsAlgorithm = SignatureAlgorithm.RS256;
}
else if (KeyType.EC.equals(jwk.getKeyType())) {
jwsAlgorithm = SignatureAlgorithm.ES256;
}
else if (KeyType.OCT.equals(jwk.getKeyType())) {
jwsAlgorithm = MacAlgorithm.HS256;
}
}
return jwsAlgorithm;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2017 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -37,8 +37,6 @@ import org.springframework.util.Assert;
*/
public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
private final ClientRegistration clientRegistration;
private final OAuth2AuthorizationExchange authorizationExchange;
/**
@ -49,21 +47,11 @@ public class OAuth2AuthorizationCodeGrantRequest extends AbstractOAuth2Authoriza
*/
public OAuth2AuthorizationCodeGrantRequest(ClientRegistration clientRegistration,
OAuth2AuthorizationExchange authorizationExchange) {
super(AuthorizationGrantType.AUTHORIZATION_CODE);
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
super(AuthorizationGrantType.AUTHORIZATION_CODE, clientRegistration);
Assert.notNull(authorizationExchange, "authorizationExchange cannot be null");
this.clientRegistration = clientRegistration;
this.authorizationExchange = authorizationExchange;
}
/**
* Returns the {@link ClientRegistration client registration}.
* @return the {@link ClientRegistration}
*/
public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}
/**
* Returns the {@link OAuth2AuthorizationExchange authorization exchange}.
* @return the {@link OAuth2AuthorizationExchange}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -29,69 +24,48 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A {@link Converter} that converts the provided
* {@link OAuth2AuthorizationCodeGrantRequest} to a {@link RequestEntity} representation
* of an OAuth 2.0 Access Token Request for the Authorization Code Grant.
* An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter}
* that converts the provided {@link OAuth2AuthorizationCodeGrantRequest} to a
* {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the
* Authorization Code Grant.
*
* @author Joe Grandja
* @since 5.1
* @see Converter
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2AuthorizationCodeGrantRequest
* @see RequestEntity
*/
public class OAuth2AuthorizationCodeGrantRequestEntityConverter
implements Converter<OAuth2AuthorizationCodeGrantRequest, RequestEntity<?>> {
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2AuthorizationCodeGrantRequest> {
/**
* Returns the {@link RequestEntity} used for the Access Token Request.
* @param authorizationCodeGrantRequest the authorization code grant request
* @return the {@link RequestEntity} used for the Access Token Request
*/
@Override
public RequestEntity<?> convert(OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration();
HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration);
MultiValueMap<String, String> formParameters = this.buildFormParameters(authorizationCodeGrantRequest);
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build()
.toUri();
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
}
/**
* Returns a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body.
* @param authorizationCodeGrantRequest the authorization code grant request
* @return a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body
*/
private MultiValueMap<String, String> buildFormParameters(
protected MultiValueMap<String, String> createParameters(
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest) {
ClientRegistration clientRegistration = authorizationCodeGrantRequest.getClientRegistration();
OAuth2AuthorizationExchange authorizationExchange = authorizationCodeGrantRequest.getAuthorizationExchange();
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue());
formParameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, authorizationCodeGrantRequest.getGrantType().getValue());
parameters.add(OAuth2ParameterNames.CODE, authorizationExchange.getAuthorizationResponse().getCode());
String redirectUri = authorizationExchange.getAuthorizationRequest().getRedirectUri();
String codeVerifier = authorizationExchange.getAuthorizationRequest()
.getAttribute(PkceParameterNames.CODE_VERIFIER);
if (redirectUri != null) {
formParameters.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
parameters.add(OAuth2ParameterNames.REDIRECT_URI, redirectUri);
}
if (!ClientAuthenticationMethod.CLIENT_SECRET_BASIC.equals(clientRegistration.getClientAuthenticationMethod())
&& !ClientAuthenticationMethod.BASIC.equals(clientRegistration.getClientAuthenticationMethod())) {
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
if (codeVerifier != null) {
formParameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier);
parameters.add(PkceParameterNames.CODE_VERIFIER, codeVerifier);
}
return formParameters;
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2018 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -34,27 +34,15 @@ import org.springframework.util.Assert;
*/
public class OAuth2ClientCredentialsGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
private final ClientRegistration clientRegistration;
/**
* Constructs an {@code OAuth2ClientCredentialsGrantRequest} using the provided
* parameters.
* @param clientRegistration the client registration
*/
public OAuth2ClientCredentialsGrantRequest(ClientRegistration clientRegistration) {
super(AuthorizationGrantType.CLIENT_CREDENTIALS);
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
super(AuthorizationGrantType.CLIENT_CREDENTIALS, clientRegistration);
Assert.isTrue(AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType()),
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");
this.clientRegistration = clientRegistration;
}
/**
* Returns the {@link ClientRegistration client registration}.
* @return the {@link ClientRegistration}
*/
public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -29,59 +24,38 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A {@link Converter} that converts the provided
* {@link OAuth2ClientCredentialsGrantRequest} to a {@link RequestEntity} representation
* of an OAuth 2.0 Access Token Request for the Client Credentials Grant.
* An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter}
* that converts the provided {@link OAuth2ClientCredentialsGrantRequest} to a
* {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the
* Client Credentials Grant.
*
* @author Joe Grandja
* @since 5.1
* @see Converter
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2ClientCredentialsGrantRequest
* @see RequestEntity
*/
public class OAuth2ClientCredentialsGrantRequestEntityConverter
implements Converter<OAuth2ClientCredentialsGrantRequest, RequestEntity<?>> {
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2ClientCredentialsGrantRequest> {
/**
* Returns the {@link RequestEntity} used for the Access Token Request.
* @param clientCredentialsGrantRequest the client credentials grant request
* @return the {@link RequestEntity} used for the Access Token Request
*/
@Override
public RequestEntity<?> convert(OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration);
MultiValueMap<String, String> formParameters = this.buildFormParameters(clientCredentialsGrantRequest);
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build()
.toUri();
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
}
/**
* Returns a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body.
* @param clientCredentialsGrantRequest the client credentials grant request
* @return a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body
*/
private MultiValueMap<String, String> buildFormParameters(
protected MultiValueMap<String, String> createParameters(
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest) {
ClientRegistration clientRegistration = clientCredentialsGrantRequest.getClientRegistration();
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, clientCredentialsGrantRequest.getGrantType().getValue());
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
formParameters.add(OAuth2ParameterNames.SCOPE,
parameters.add(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return formParameters;
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -33,8 +33,6 @@ import org.springframework.util.Assert;
*/
public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
private final ClientRegistration clientRegistration;
private final String username;
private final String password;
@ -46,25 +44,15 @@ public class OAuth2PasswordGrantRequest extends AbstractOAuth2AuthorizationGrant
* @param password the resource owner's password
*/
public OAuth2PasswordGrantRequest(ClientRegistration clientRegistration, String username, String password) {
super(AuthorizationGrantType.PASSWORD);
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
super(AuthorizationGrantType.PASSWORD, clientRegistration);
Assert.isTrue(AuthorizationGrantType.PASSWORD.equals(clientRegistration.getAuthorizationGrantType()),
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.PASSWORD");
Assert.hasText(username, "username cannot be empty");
Assert.hasText(password, "password cannot be empty");
this.clientRegistration = clientRegistration;
this.username = username;
this.password = password;
}
/**
* Returns the {@link ClientRegistration client registration}.
* @return the {@link ClientRegistration}
*/
public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}
/**
* Returns the resource owner's username.
* @return the resource owner's username

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -29,60 +24,39 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A {@link Converter} that converts the provided {@link OAuth2PasswordGrantRequest} to a
* An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter}
* that converts the provided {@link OAuth2PasswordGrantRequest} to a
* {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the
* Resource Owner Password Credentials Grant.
*
* @author Joe Grandja
* @since 5.2
* @see Converter
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2PasswordGrantRequest
* @see RequestEntity
*/
public class OAuth2PasswordGrantRequestEntityConverter
implements Converter<OAuth2PasswordGrantRequest, RequestEntity<?>> {
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2PasswordGrantRequest> {
/**
* Returns the {@link RequestEntity} used for the Access Token Request.
* @param passwordGrantRequest the password grant request
* @return the {@link RequestEntity} used for the Access Token Request
*/
@Override
public RequestEntity<?> convert(OAuth2PasswordGrantRequest passwordGrantRequest) {
protected MultiValueMap<String, String> createParameters(OAuth2PasswordGrantRequest passwordGrantRequest) {
ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration();
HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration);
MultiValueMap<String, String> formParameters = buildFormParameters(passwordGrantRequest);
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build()
.toUri();
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
}
/**
* Returns a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body.
* @param passwordGrantRequest the password grant request
* @return a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body
*/
private MultiValueMap<String, String> buildFormParameters(OAuth2PasswordGrantRequest passwordGrantRequest) {
ClientRegistration clientRegistration = passwordGrantRequest.getClientRegistration();
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue());
formParameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername());
formParameters.add(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, passwordGrantRequest.getGrantType().getValue());
parameters.add(OAuth2ParameterNames.USERNAME, passwordGrantRequest.getUsername());
parameters.add(OAuth2ParameterNames.PASSWORD, passwordGrantRequest.getPassword());
if (!CollectionUtils.isEmpty(clientRegistration.getScopes())) {
formParameters.add(OAuth2ParameterNames.SCOPE,
parameters.add(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(clientRegistration.getScopes(), " "));
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return formParameters;
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,8 +39,6 @@ import org.springframework.util.Assert;
*/
public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationGrantRequest {
private final ClientRegistration clientRegistration;
private final OAuth2AccessToken accessToken;
private final OAuth2RefreshToken refreshToken;
@ -67,25 +65,15 @@ public class OAuth2RefreshTokenGrantRequest extends AbstractOAuth2AuthorizationG
*/
public OAuth2RefreshTokenGrantRequest(ClientRegistration clientRegistration, OAuth2AccessToken accessToken,
OAuth2RefreshToken refreshToken, Set<String> scopes) {
super(AuthorizationGrantType.REFRESH_TOKEN);
Assert.notNull(clientRegistration, "clientRegistration cannot be null");
super(AuthorizationGrantType.REFRESH_TOKEN, clientRegistration);
Assert.notNull(accessToken, "accessToken cannot be null");
Assert.notNull(refreshToken, "refreshToken cannot be null");
this.clientRegistration = clientRegistration;
this.accessToken = accessToken;
this.refreshToken = refreshToken;
this.scopes = Collections
.unmodifiableSet((scopes != null) ? new LinkedHashSet<>(scopes) : Collections.emptySet());
}
/**
* Returns the authorized client's {@link ClientRegistration registration}.
* @return the {@link ClientRegistration}
*/
public ClientRegistration getClientRegistration() {
return this.clientRegistration;
}
/**
* Returns the {@link OAuth2AccessToken access token} credential granted.
* @return the {@link OAuth2AccessToken}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,6 @@
package org.springframework.security.oauth2.client.endpoint;
import java.net.URI;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@ -29,60 +24,38 @@ import org.springframework.util.CollectionUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponentsBuilder;
/**
* A {@link Converter} that converts the provided {@link OAuth2RefreshTokenGrantRequest}
* to a {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the
* An implementation of an {@link AbstractOAuth2AuthorizationGrantRequestEntityConverter}
* that converts the provided {@link OAuth2RefreshTokenGrantRequest} to a
* {@link RequestEntity} representation of an OAuth 2.0 Access Token Request for the
* Refresh Token Grant.
*
* @author Joe Grandja
* @since 5.2
* @see Converter
* @see AbstractOAuth2AuthorizationGrantRequestEntityConverter
* @see OAuth2RefreshTokenGrantRequest
* @see RequestEntity
*/
public class OAuth2RefreshTokenGrantRequestEntityConverter
implements Converter<OAuth2RefreshTokenGrantRequest, RequestEntity<?>> {
extends AbstractOAuth2AuthorizationGrantRequestEntityConverter<OAuth2RefreshTokenGrantRequest> {
/**
* Returns the {@link RequestEntity} used for the Access Token Request.
* @param refreshTokenGrantRequest the refresh token grant request
* @return the {@link RequestEntity} used for the Access Token Request
*/
@Override
public RequestEntity<?> convert(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
protected MultiValueMap<String, String> createParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
HttpHeaders headers = OAuth2AuthorizationGrantRequestEntityUtils.getTokenRequestHeaders(clientRegistration);
MultiValueMap<String, String> formParameters = buildFormParameters(refreshTokenGrantRequest);
URI uri = UriComponentsBuilder.fromUriString(clientRegistration.getProviderDetails().getTokenUri()).build()
.toUri();
return new RequestEntity<>(formParameters, headers, HttpMethod.POST, uri);
}
/**
* Returns a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body.
* @param refreshTokenGrantRequest the refresh token grant request
* @return a {@link MultiValueMap} of the form parameters used for the Access Token
* Request body
*/
private MultiValueMap<String, String> buildFormParameters(OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest) {
ClientRegistration clientRegistration = refreshTokenGrantRequest.getClientRegistration();
MultiValueMap<String, String> formParameters = new LinkedMultiValueMap<>();
formParameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue());
formParameters.add(OAuth2ParameterNames.REFRESH_TOKEN,
refreshTokenGrantRequest.getRefreshToken().getTokenValue());
MultiValueMap<String, String> parameters = new LinkedMultiValueMap<>();
parameters.add(OAuth2ParameterNames.GRANT_TYPE, refreshTokenGrantRequest.getGrantType().getValue());
parameters.add(OAuth2ParameterNames.REFRESH_TOKEN, refreshTokenGrantRequest.getRefreshToken().getTokenValue());
if (!CollectionUtils.isEmpty(refreshTokenGrantRequest.getScopes())) {
formParameters.add(OAuth2ParameterNames.SCOPE,
parameters.add(OAuth2ParameterNames.SCOPE,
StringUtils.collectionToDelimitedString(refreshTokenGrantRequest.getScopes(), " "));
}
if (ClientAuthenticationMethod.CLIENT_SECRET_POST.equals(clientRegistration.getClientAuthenticationMethod())
|| ClientAuthenticationMethod.POST.equals(clientRegistration.getClientAuthenticationMethod())) {
formParameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
formParameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
parameters.add(OAuth2ParameterNames.CLIENT_ID, clientRegistration.getClientId());
parameters.add(OAuth2ParameterNames.CLIENT_SECRET, clientRegistration.getClientSecret());
}
return formParameters;
return parameters;
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,13 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -29,7 +34,7 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
@ -37,6 +42,8 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenRespon
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -49,32 +56,25 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/
public class DefaultAuthorizationCodeTokenResponseClientTests {
private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
private DefaultAuthorizationCodeTokenResponseClient tokenResponseClient;
private ClientRegistration clientRegistration;
private ClientRegistration.Builder clientRegistration;
private MockWebServer server;
@Before
public void setup() throws Exception {
this.tokenResponseClient = new DefaultAuthorizationCodeTokenResponseClient();
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = ClientRegistration
.withRegistrationId("registration-1")
this.clientRegistration = TestClientRegistrations.clientRegistration()
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("https://client.com/callback/client-1")
.scope("read", "write")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri(tokenUri)
.userInfoUri("https://provider.com/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
.scope("read", "write");
// @formatter:on
}
@ -114,7 +114,7 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(this.authorizationCodeGrantRequest());
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()));
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getMethod()).isEqualTo(HttpMethod.POST.toString());
@ -136,7 +136,7 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception {
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
@ -145,13 +145,13 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest());
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()));
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic ");
}
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
@ -160,9 +160,9 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.from(this.clientRegistration)
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build();
this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration));
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration));
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
@ -170,6 +170,79 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration));
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration));
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2AuthorizationCodeGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
OAuth2AuthorizationCodeGrantRequestEntityConverter requestEntityConverter = new OAuth2AuthorizationCodeGrantRequestEntityConverter();
requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter);
this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -181,7 +254,8 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
.isThrownBy(() -> this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())))
.withMessageContaining(
"[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
.withMessageContaining("tokenType cannot be null");
@ -196,7 +270,8 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
.isThrownBy(() -> this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())))
.withMessageContaining(
"[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response")
.withMessageContaining("tokenType cannot be null");
@ -215,7 +290,7 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(this.authorizationCodeGrantRequest());
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()));
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read");
}
@ -231,16 +306,16 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(this.authorizationCodeGrantRequest());
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build()));
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write");
}
@Test
public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() {
String invalidTokenUri = "https://invalid-provider.com/oauth2/token";
ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build();
ClientRegistration clientRegistration = this.clientRegistration.tokenUri(invalidTokenUri).build();
assertThatExceptionOfType(OAuth2AuthorizationException.class).isThrownBy(
() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest(clientRegistration)))
() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest(clientRegistration)))
.withMessageContaining(
"[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response");
}
@ -260,7 +335,8 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
.isThrownBy(() -> this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())))
.withMessageContaining(
"[invalid_token_response] An error occurred while attempting to retrieve the OAuth 2.0 Access Token Response");
}
@ -270,7 +346,8 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
.isThrownBy(() -> this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())))
.withMessageContaining("[unauthorized_client]");
}
@ -278,15 +355,12 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
this.server.enqueue(new MockResponse().setResponseCode(500));
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
.isThrownBy(() -> this.tokenResponseClient
.getTokenResponse(authorizationCodeGrantRequest(this.clientRegistration.build())))
.withMessageContaining("[invalid_token_response] An error occurred while attempting to retrieve "
+ "the OAuth 2.0 Access Token Response");
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest() {
return this.authorizationCodeGrantRequest(this.clientRegistration);
}
private OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest(ClientRegistration clientRegistration) {
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId(clientRegistration.getClientId()).state("state-1234")
@ -303,22 +377,4 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private ClientRegistration.Builder from(ClientRegistration registration) {
// @formatter:off
return ClientRegistration.withRegistrationId(registration.getRegistrationId())
.clientId(registration.getClientId())
.clientSecret(registration.getClientSecret())
.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
.authorizationGrantType(registration.getAuthorizationGrantType())
.redirectUri(registration.getRedirectUri())
.scope(registration.getScopes())
.authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.tokenUri(registration.getProviderDetails().getTokenUri())
.userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri())
.userNameAttributeName(
registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName())
.clientName(registration.getClientName());
// @formatter:on
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,13 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -29,11 +34,13 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -46,26 +53,24 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/
public class DefaultClientCredentialsTokenResponseClientTests {
private DefaultClientCredentialsTokenResponseClient tokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
private DefaultClientCredentialsTokenResponseClient tokenResponseClient;
private ClientRegistration clientRegistration;
private ClientRegistration.Builder clientRegistration;
private MockWebServer server;
@Before
public void setup() throws Exception {
this.tokenResponseClient = new DefaultClientCredentialsTokenResponseClient();
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
// @formatter:off
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
this.clientRegistration = TestClientRegistrations.clientCredentials()
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri(tokenUri)
.build();
.scope("read", "write");
// @formatter:on
}
@ -110,7 +115,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(clientCredentialsGrantRequest);
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
@ -133,7 +138,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception {
public void getTokenResponseWhenAuthenticationClientSecretBasicThenAuthorizationHeaderIsSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
@ -143,14 +148,14 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).startsWith("Basic ");
}
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
@ -159,7 +164,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.from(this.clientRegistration)
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build();
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
@ -171,6 +176,83 @@ public class DefaultClientCredentialsTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
OAuth2ClientCredentialsGrantRequestEntityConverter requestEntityConverter = new OAuth2ClientCredentialsGrantRequestEntityConverter();
requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter);
this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -182,7 +264,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
.withMessageContaining(
@ -195,7 +277,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
.withMessageContaining(
@ -215,7 +297,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(clientCredentialsGrantRequest);
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read");
@ -232,7 +314,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(clientCredentialsGrantRequest);
assertThat(accessTokenResponse.getAccessToken().getScopes()).containsExactly("read", "write");
@ -241,7 +323,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test
public void getTokenResponseWhenTokenUriInvalidThenThrowOAuth2AuthorizationException() {
String invalidTokenUri = "https://invalid-provider.com/oauth2/token";
ClientRegistration clientRegistration = this.from(this.clientRegistration).tokenUri(invalidTokenUri).build();
ClientRegistration clientRegistration = this.clientRegistration.tokenUri(invalidTokenUri).build();
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
@ -264,7 +346,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
.withMessageContaining(
@ -280,7 +362,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
.withMessageContaining("[unauthorized_client]");
@ -290,7 +372,7 @@ public class DefaultClientCredentialsTokenResponseClientTests {
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
this.server.enqueue(new MockResponse().setResponseCode(500));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration);
this.clientRegistration.build());
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(clientCredentialsGrantRequest))
.withMessageContaining(
@ -301,16 +383,4 @@ public class DefaultClientCredentialsTokenResponseClientTests {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json);
}
private ClientRegistration.Builder from(ClientRegistration registration) {
// @formatter:off
return ClientRegistration.withRegistrationId(registration.getRegistrationId())
.clientId(registration.getClientId())
.clientSecret(registration.getClientSecret())
.clientAuthenticationMethod(registration.getClientAuthenticationMethod())
.authorizationGrantType(registration.getAuthorizationGrantType())
.scope(registration.getScopes())
.tokenUri(registration.getProviderDetails().getTokenUri());
// @formatter:on
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,8 +16,13 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -30,11 +35,12 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -47,9 +53,9 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/
public class DefaultPasswordTokenResponseClientTests {
private DefaultPasswordTokenResponseClient tokenResponseClient = new DefaultPasswordTokenResponseClient();
private DefaultPasswordTokenResponseClient tokenResponseClient;
private ClientRegistration.Builder clientRegistrationBuilder;
private ClientRegistration.Builder clientRegistration;
private String username = "user1";
@ -59,11 +65,15 @@ public class DefaultPasswordTokenResponseClientTests {
@Before
public void setup() throws Exception {
this.tokenResponseClient = new DefaultPasswordTokenResponseClient();
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
.authorizationGrantType(AuthorizationGrantType.PASSWORD).scope("read", "write").tokenUri(tokenUri);
// @formatter:off
this.clientRegistration = TestClientRegistrations.password()
.scope("read", "write")
.tokenUri(tokenUri);
// @formatter:on
}
@After
@ -97,7 +107,7 @@ public class DefaultPasswordTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
ClientRegistration clientRegistration = this.clientRegistration.build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest);
@ -121,7 +131,7 @@ public class DefaultPasswordTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
@ -130,7 +140,7 @@ public class DefaultPasswordTokenResponseClientTests {
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
@ -142,6 +152,83 @@ public class DefaultPasswordTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=client-secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.tokenResponseClient.getTokenResponse(passwordGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration,
this.username, this.password);
this.tokenResponseClient.getTokenResponse(passwordGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2PasswordGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
OAuth2PasswordGrantRequestEntityConverter requestEntityConverter = new OAuth2PasswordGrantRequestEntityConverter();
requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter);
this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -153,7 +240,7 @@ public class DefaultPasswordTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
this.clientRegistration.build(), this.username, this.password);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest))
.withMessageContaining(
@ -173,7 +260,7 @@ public class DefaultPasswordTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
this.clientRegistration.build(), this.username, this.password);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient.getTokenResponse(passwordGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
String formParameters = recordedRequest.getBody().readUtf8();
@ -186,7 +273,7 @@ public class DefaultPasswordTokenResponseClientTests {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
this.clientRegistration.build(), this.username, this.password);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest))
.withMessageContaining("[unauthorized_client]");
@ -196,7 +283,7 @@ public class DefaultPasswordTokenResponseClientTests {
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
this.server.enqueue(new MockResponse().setResponseCode(500));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password);
this.clientRegistration.build(), this.username, this.password);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(passwordGrantRequest))
.withMessageContaining("[invalid_token_response] An error occurred while attempting to "

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,9 +16,14 @@
package org.springframework.security.oauth2.client.endpoint;
import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Collections;
import java.util.function.Function;
import javax.crypto.spec.SecretKeySpec;
import com.nimbusds.jose.jwk.JWK;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
@ -38,6 +43,8 @@ import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
@ -50,9 +57,9 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
*/
public class DefaultRefreshTokenTokenResponseClientTests {
private DefaultRefreshTokenTokenResponseClient tokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
private DefaultRefreshTokenTokenResponseClient tokenResponseClient;
private ClientRegistration.Builder clientRegistrationBuilder;
private ClientRegistration.Builder clientRegistration;
private OAuth2AccessToken accessToken;
@ -62,10 +69,11 @@ public class DefaultRefreshTokenTokenResponseClientTests {
@Before
public void setup() throws Exception {
this.tokenResponseClient = new DefaultRefreshTokenTokenResponseClient();
this.server = new MockWebServer();
this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().tokenUri(tokenUri);
this.clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(tokenUri);
this.accessToken = TestOAuth2AccessTokens.scopes("read", "write");
this.refreshToken = TestOAuth2RefreshTokens.refreshToken();
}
@ -102,7 +110,7 @@ public class DefaultRefreshTokenTokenResponseClientTests {
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.clientRegistration.build(), this.accessToken, this.refreshToken);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(refreshTokenGrantRequest);
Instant expiresAtAfter = Instant.now().plusSeconds(3600);
@ -124,11 +132,16 @@ public class DefaultRefreshTokenTokenResponseClientTests {
}
@Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n";
public void getTokenResponseWhenAuthenticationClientSecretPostThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_POST).build();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
@ -140,6 +153,83 @@ public class DefaultRefreshTokenTokenResponseClientTests {
assertThat(formParameters).contains("client_secret=client-secret");
}
@Test
public void getTokenResponseWhenAuthenticationClientSecretJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.clientSecret(TestKeys.DEFAULT_ENCODED_SECRET_KEY)
.build();
// @formatter:on
// Configure Jwt client authentication converter
SecretKeySpec secretKey = new SecretKeySpec(
clientRegistration.getClientSecret().getBytes(StandardCharsets.UTF_8), "HmacSHA256");
JWK jwk = TestJwks.jwk(secretKey).build();
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
@Test
public void getTokenResponseWhenAuthenticationPrivateKeyJwtThenFormParametersAreSent() throws Exception {
// @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ClientRegistration clientRegistration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
// Configure Jwt client authentication converter
JWK jwk = TestJwks.DEFAULT_RSA_JWK;
Function<ClientRegistration, JWK> jwkResolver = (registration) -> jwk;
configureJwtClientAuthenticationConverter(jwkResolver);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
this.accessToken, this.refreshToken);
this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
assertThat(recordedRequest.getHeader(HttpHeaders.AUTHORIZATION)).isNull();
String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters)
.contains("client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer");
assertThat(formParameters).contains("client_assertion=");
}
private void configureJwtClientAuthenticationConverter(Function<ClientRegistration, JWK> jwkResolver) {
NimbusJwtClientAuthenticationParametersConverter<OAuth2RefreshTokenGrantRequest> jwtClientAuthenticationConverter = new NimbusJwtClientAuthenticationParametersConverter<>(
jwkResolver);
OAuth2RefreshTokenGrantRequestEntityConverter requestEntityConverter = new OAuth2RefreshTokenGrantRequestEntityConverter();
requestEntityConverter.addParametersConverter(jwtClientAuthenticationConverter);
this.tokenResponseClient.setRequestEntityConverter(requestEntityConverter);
}
@Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
// @formatter:off
@ -151,7 +241,7 @@ public class DefaultRefreshTokenTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.clientRegistration.build(), this.accessToken, this.refreshToken);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest))
.withMessageContaining("[invalid_token_response] An error occurred while attempting to "
@ -171,8 +261,7 @@ public class DefaultRefreshTokenTokenResponseClientTests {
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken,
Collections.singleton("read"));
this.clientRegistration.build(), this.accessToken, this.refreshToken, Collections.singleton("read"));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(refreshTokenGrantRequest);
RecordedRequest recordedRequest = this.server.takeRequest();
@ -186,7 +275,7 @@ public class DefaultRefreshTokenTokenResponseClientTests {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.clientRegistration.build(), this.accessToken, this.refreshToken);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest))
.withMessageContaining("[unauthorized_client]");
@ -196,7 +285,7 @@ public class DefaultRefreshTokenTokenResponseClientTests {
public void getTokenResponseWhenServerErrorResponseThenThrowOAuth2AuthorizationException() {
this.server.enqueue(new MockResponse().setResponseCode(500));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
this.clientRegistration.build(), this.accessToken, this.refreshToken);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(refreshTokenGrantRequest))
.withMessageContaining("[invalid_token_response] An error occurred while attempting to "

View File

@ -0,0 +1,123 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.junit.Test;
import org.springframework.security.oauth2.jose.JwaAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* Tests for {@link JoseHeader}.
*
* @author Joe Grandja
*/
public class JoseHeaderTests {
@Test
public void withAlgorithmWhenNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.withAlgorithm(null))
.isInstanceOf(IllegalArgumentException.class).withMessage("jwaAlgorithm cannot be null");
}
@Test
public void buildWhenAllHeadersProvidedThenAllHeadersAreSet() {
JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build();
// @formatter:off
JoseHeader joseHeader = JoseHeader.withAlgorithm(expectedJoseHeader.getAlgorithm())
.jwkSetUrl(expectedJoseHeader.getJwkSetUrl().toExternalForm())
.jwk(expectedJoseHeader.getJwk())
.keyId(expectedJoseHeader.getKeyId())
.x509Url(expectedJoseHeader.getX509Url().toExternalForm())
.x509CertificateChain(expectedJoseHeader.getX509CertificateChain())
.x509SHA1Thumbprint(expectedJoseHeader.getX509SHA1Thumbprint())
.x509SHA256Thumbprint(expectedJoseHeader.getX509SHA256Thumbprint())
.type(expectedJoseHeader.getType())
.contentType(expectedJoseHeader.getContentType())
.headers((headers) -> headers.put("custom-header-name", "custom-header-value"))
.build();
// @formatter:on
assertThat(joseHeader.<JwaAlgorithm>getAlgorithm()).isEqualTo(expectedJoseHeader.getAlgorithm());
assertThat(joseHeader.getJwkSetUrl()).isEqualTo(expectedJoseHeader.getJwkSetUrl());
assertThat(joseHeader.getJwk()).isEqualTo(expectedJoseHeader.getJwk());
assertThat(joseHeader.getKeyId()).isEqualTo(expectedJoseHeader.getKeyId());
assertThat(joseHeader.getX509Url()).isEqualTo(expectedJoseHeader.getX509Url());
assertThat(joseHeader.getX509CertificateChain()).isEqualTo(expectedJoseHeader.getX509CertificateChain());
assertThat(joseHeader.getX509SHA1Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA1Thumbprint());
assertThat(joseHeader.getX509SHA256Thumbprint()).isEqualTo(expectedJoseHeader.getX509SHA256Thumbprint());
assertThat(joseHeader.getType()).isEqualTo(expectedJoseHeader.getType());
assertThat(joseHeader.getContentType()).isEqualTo(expectedJoseHeader.getContentType());
assertThat(joseHeader.<String>getHeader("custom-header-name")).isEqualTo("custom-header-value");
assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders());
}
@Test
public void fromWhenNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JoseHeader.from(null))
.isInstanceOf(IllegalArgumentException.class).withMessage("headers cannot be null");
}
@Test
public void fromWhenHeadersProvidedThenCopied() {
JoseHeader expectedJoseHeader = TestJoseHeaders.joseHeader().build();
JoseHeader joseHeader = JoseHeader.from(expectedJoseHeader).build();
assertThat(joseHeader.getHeaders()).isEqualTo(expectedJoseHeader.getHeaders());
}
@Test
public void headerWhenNameNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header(null, "value"))
.withMessage("name cannot be empty");
}
@Test
public void headerWhenValueNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).header("name", null))
.withMessage("value cannot be null");
}
@Test
public void getHeaderWhenNullThenThrowIllegalArgumentException() {
JoseHeader joseHeader = TestJoseHeaders.joseHeader().build();
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> joseHeader.getHeader(null))
.isInstanceOf(IllegalArgumentException.class).withMessage("name cannot be empty");
}
}

View File

@ -0,0 +1,105 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import org.junit.Test;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* Tests for {@link JwtClaimsSet}.
*
* @author Joe Grandja
*/
public class JwtClaimsSetTests {
@Test
public void buildWhenClaimsEmptyThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.builder().build())
.isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be empty");
}
@Test
public void buildWhenAllClaimsProvidedThenAllClaimsAreSet() {
JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
// @formatter:off
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.builder()
.issuer(expectedJwtClaimsSet.getIssuer().toExternalForm())
.subject(expectedJwtClaimsSet.getSubject())
.audience(expectedJwtClaimsSet.getAudience())
.issuedAt(expectedJwtClaimsSet.getIssuedAt())
.notBefore(expectedJwtClaimsSet.getNotBefore())
.expiresAt(expectedJwtClaimsSet.getExpiresAt())
.id(expectedJwtClaimsSet.getId())
.claims((claims) -> claims.put("custom-claim-name", "custom-claim-value"))
.build();
// @formatter:on
assertThat(jwtClaimsSet.getIssuer()).isEqualTo(expectedJwtClaimsSet.getIssuer());
assertThat(jwtClaimsSet.getSubject()).isEqualTo(expectedJwtClaimsSet.getSubject());
assertThat(jwtClaimsSet.getAudience()).isEqualTo(expectedJwtClaimsSet.getAudience());
assertThat(jwtClaimsSet.getIssuedAt()).isEqualTo(expectedJwtClaimsSet.getIssuedAt());
assertThat(jwtClaimsSet.getNotBefore()).isEqualTo(expectedJwtClaimsSet.getNotBefore());
assertThat(jwtClaimsSet.getExpiresAt()).isEqualTo(expectedJwtClaimsSet.getExpiresAt());
assertThat(jwtClaimsSet.getId()).isEqualTo(expectedJwtClaimsSet.getId());
assertThat(jwtClaimsSet.<String>getClaim("custom-claim-name")).isEqualTo("custom-claim-value");
assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims());
}
@Test
public void fromWhenNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(() -> JwtClaimsSet.from(null))
.isInstanceOf(IllegalArgumentException.class).withMessage("claims cannot be null");
}
@Test
public void fromWhenClaimsProvidedThenCopied() {
JwtClaimsSet expectedJwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
JwtClaimsSet jwtClaimsSet = JwtClaimsSet.from(expectedJwtClaimsSet).build();
assertThat(jwtClaimsSet.getClaims()).isEqualTo(expectedJwtClaimsSet.getClaims());
}
@Test
public void claimWhenNameNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> JwtClaimsSet.builder().claim(null, "value")).withMessage("name cannot be empty");
}
@Test
public void claimWhenValueNullThenThrowIllegalArgumentException() {
assertThatExceptionOfType(IllegalArgumentException.class)
.isThrownBy(() -> JwtClaimsSet.builder().claim("name", null)).withMessage("value cannot be null");
}
}

View File

@ -0,0 +1,347 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import com.nimbusds.jose.KeySourceException;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.JWKSelector;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.KeyUse;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jose.jwk.source.JWKSource;
import com.nimbusds.jose.proc.SecurityContext;
import com.nimbusds.jose.util.Base64URL;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.TestKeys;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.willAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* Tests for {@link NimbusJwsEncoder}.
*
* @author Joe Grandja
*/
public class NimbusJwsEncoderTests {
private List<JWK> jwkList;
private JWKSource<SecurityContext> jwkSource;
private NimbusJwsEncoder jwsEncoder;
@Before
public void setUp() {
this.jwkList = new ArrayList<>();
this.jwkSource = (jwkSelector, securityContext) -> jwkSelector.select(new JWKSet(this.jwkList));
this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource);
}
@Test
public void constructorWhenJwkSourceNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> new NimbusJwsEncoder(null))
.withMessage("jwkSource cannot be null");
}
@Test
public void encodeWhenHeadersNullThenThrowIllegalArgumentException() {
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(null, jwtClaimsSet))
.withMessage("headers cannot be null");
}
@Test
public void encodeWhenClaimsNullThenThrowIllegalArgumentException() {
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
assertThatIllegalArgumentException().isThrownBy(() -> this.jwsEncoder.encode(joseHeader, null))
.withMessage("claims cannot be null");
}
@Test
public void encodeWhenJwkSelectFailedThenThrowJwtEncodingException() throws Exception {
this.jwkSource = mock(JWKSource.class);
this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource);
given(this.jwkSource.get(any(), any())).willThrow(new KeySourceException("key source error"));
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet))
.withMessageContaining("Failed to select a JWK signing key -> key source error");
}
@Test
public void encodeWhenJwkMultipleSelectedThenThrowJwtEncodingException() throws Exception {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
this.jwkList.add(rsaJwk);
this.jwkList.add(rsaJwk);
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet))
.withMessageContaining("Found multiple JWK signing keys for algorithm 'RS256'");
}
@Test
public void encodeWhenJwkSelectEmptyThenThrowJwtEncodingException() {
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet))
.withMessageContaining("Failed to select a JWK signing key");
}
@Test
public void encodeWhenJwkSelectWithProvidedKidThenSelected() {
// @formatter:off
RSAKey rsaJwk1 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.keyID("rsa-jwk-1")
.build();
this.jwkList.add(rsaJwk1);
RSAKey rsaJwk2 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.keyID("rsa-jwk-2")
.build();
this.jwkList.add(rsaJwk2);
// @formatter:on
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).keyId(rsaJwk2.getKeyID()).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk2.getKeyID());
}
@Test
public void encodeWhenJwkSelectWithProvidedX5TS256ThenSelected() {
// @formatter:off
RSAKey rsaJwk1 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-1"))
.keyID(null)
.build();
this.jwkList.add(rsaJwk1);
RSAKey rsaJwk2 = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-2"))
.keyID(null)
.build();
this.jwkList.add(rsaJwk2);
// @formatter:on
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256)
.x509SHA256Thumbprint(rsaJwk1.getX509CertSHA256Thumbprint().toString()).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256))
.isEqualTo(rsaJwk1.getX509CertSHA256Thumbprint().toString());
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isNull();
}
@Test
public void encodeWhenJwkUseEncryptionThenThrowJwtEncodingException() throws Exception {
// @formatter:off
RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.keyUse(KeyUse.ENCRYPTION)
.build();
// @formatter:on
this.jwkSource = mock(JWKSource.class);
this.jwsEncoder = new NimbusJwsEncoder(this.jwkSource);
given(this.jwkSource.get(any(), any())).willReturn(Collections.singletonList(rsaJwk));
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
assertThatExceptionOfType(JwtEncodingException.class)
.isThrownBy(() -> this.jwsEncoder.encode(joseHeader, jwtClaimsSet)).withMessageContaining(
"Failed to create a JWS Signer -> The JWK use must be sig (signature) or unspecified");
}
@Test
public void encodeWhenSuccessThenDecodes() throws Exception {
// @formatter:off
RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.keyID("rsa-jwk-1")
.x509CertSHA256Thumbprint(new Base64URL("x509CertSHA256Thumbprint-1"))
.build();
this.jwkList.add(rsaJwk);
// @formatter:on
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
Jwt encodedJws = this.jwsEncoder.encode(joseHeader, jwtClaimsSet);
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(joseHeader.getAlgorithm());
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JKU)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.JWK)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID());
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5U)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5C)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.X5T_S256))
.isEqualTo(rsaJwk.getX509CertSHA256Thumbprint().toString());
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.TYP)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CTY)).isNull();
assertThat(encodedJws.getHeaders().get(JoseHeaderNames.CRIT)).isNull();
assertThat(encodedJws.getIssuer()).isEqualTo(jwtClaimsSet.getIssuer());
assertThat(encodedJws.getSubject()).isEqualTo(jwtClaimsSet.getSubject());
assertThat(encodedJws.getAudience()).isEqualTo(jwtClaimsSet.getAudience());
assertThat(encodedJws.getExpiresAt()).isEqualTo(jwtClaimsSet.getExpiresAt());
assertThat(encodedJws.getNotBefore()).isEqualTo(jwtClaimsSet.getNotBefore());
assertThat(encodedJws.getIssuedAt()).isEqualTo(jwtClaimsSet.getIssuedAt());
assertThat(encodedJws.getId()).isEqualTo(jwtClaimsSet.getId());
assertThat(encodedJws.<String>getClaim("custom-claim-name")).isEqualTo("custom-claim-value");
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws.getTokenValue());
}
@Test
public void encodeWhenKeysRotatedThenNewKeyUsed() throws Exception {
TestJWKSource jwkSource = new TestJWKSource();
JWKSource<SecurityContext> jwkSourceDelegate = spy(new JWKSource<SecurityContext>() {
@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) {
return jwkSource.get(jwkSelector, context);
}
});
NimbusJwsEncoder jwsEncoder = new NimbusJwsEncoder(jwkSourceDelegate);
JwkListResultCaptor jwkListResultCaptor = new JwkListResultCaptor();
willAnswer(jwkListResultCaptor).given(jwkSourceDelegate).get(any(), any());
JoseHeader joseHeader = JoseHeader.withAlgorithm(SignatureAlgorithm.RS256).build();
JwtClaimsSet jwtClaimsSet = TestJwtClaimsSets.jwtClaimsSet().build();
Jwt encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
JWK jwk1 = jwkListResultCaptor.getResult().get(0);
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk1).toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws.getTokenValue());
jwkSource.rotate(); // Simulate key rotation
encodedJws = jwsEncoder.encode(joseHeader, jwtClaimsSet);
JWK jwk2 = jwkListResultCaptor.getResult().get(0);
jwtDecoder = NimbusJwtDecoder.withPublicKey(((RSAKey) jwk2).toRSAPublicKey()).build();
jwtDecoder.decode(encodedJws.getTokenValue());
assertThat(jwk1.getKeyID()).isNotEqualTo(jwk2.getKeyID());
}
private static final class JwkListResultCaptor implements Answer<List<JWK>> {
private List<JWK> result;
private List<JWK> getResult() {
return this.result;
}
@SuppressWarnings("unchecked")
@Override
public List<JWK> answer(InvocationOnMock invocationOnMock) throws Throwable {
this.result = (List<JWK>) invocationOnMock.callRealMethod();
return this.result;
}
}
private static final class TestJWKSource implements JWKSource<SecurityContext> {
private int keyId = 1000;
private JWKSet jwkSet;
private TestJWKSource() {
init();
}
@Override
public List<JWK> get(JWKSelector jwkSelector, SecurityContext context) {
return jwkSelector.select(this.jwkSet);
}
private void init() {
// @formatter:off
RSAKey rsaJwk = TestJwks.jwk(TestKeys.DEFAULT_PUBLIC_KEY, TestKeys.DEFAULT_PRIVATE_KEY)
.keyID("rsa-jwk-" + this.keyId++)
.build();
ECKey ecJwk = TestJwks.jwk((ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(), (ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate())
.keyID("ec-jwk-" + this.keyId++)
.build();
OctetSequenceKey secretJwk = TestJwks.jwk(TestKeys.DEFAULT_SECRET_KEY)
.keyID("secret-jwk-" + this.keyId++)
.build();
// @formatter:on
this.jwkSet = new JWKSet(Arrays.asList(rsaJwk, ecJwk, secretJwk));
}
private void rotate() {
init();
}
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.util.Collections;
import java.util.function.Function;
import com.nimbusds.jose.jwk.JWK;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
import org.junit.Before;
import org.junit.Test;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.OAuth2AuthorizationException;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.jose.TestJwks;
import org.springframework.security.oauth2.jose.jws.MacAlgorithm;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
import org.springframework.security.oauth2.jwt.Jwt;
import org.springframework.security.oauth2.jwt.JwtClaimNames;
import org.springframework.security.oauth2.jwt.NimbusJwtDecoder;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verifyNoInteractions;
/**
* Tests for {@link NimbusJwtClientAuthenticationParametersConverter}.
*
* @author Joe Grandja
*/
public class NimbusJwtClientAuthenticationParametersConverterTests {
private Function<ClientRegistration, JWK> jwkResolver;
private NimbusJwtClientAuthenticationParametersConverter<OAuth2ClientCredentialsGrantRequest> converter;
@Before
public void setup() {
this.jwkResolver = mock(Function.class);
this.converter = new NimbusJwtClientAuthenticationParametersConverter<>(this.jwkResolver);
}
@Test
public void constructorWhenJwkResolverNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException()
.isThrownBy(() -> new NimbusJwtClientAuthenticationParametersConverter<>(null))
.withMessage("jwkResolver cannot be null");
}
@Test
public void convertWhenAuthorizationGrantRequestNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.convert(null))
.withMessage("authorizationGrantRequest cannot be null");
}
@Test
public void convertWhenOtherClientAuthenticationMethodThenNotCustomized() {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.build();
// @formatter:on
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
assertThat(this.converter.convert(clientCredentialsGrantRequest)).isNull();
verifyNoInteractions(this.jwkResolver);
}
@Test
public void convertWhenJwkNotResolvedThenThrowOAuth2AuthorizationException() {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.converter.convert(clientCredentialsGrantRequest))
.withMessage("[invalid_key] Failed to resolve JWK signing key for client registration '"
+ clientRegistration.getRegistrationId() + "'.");
}
@Test
public void convertWhenPrivateKeyJwtClientAuthenticationMethodThenCustomized() throws Exception {
RSAKey rsaJwk = TestJwks.DEFAULT_RSA_JWK;
given(this.jwkResolver.apply(any())).willReturn(rsaJwk);
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.PRIVATE_KEY_JWT)
.build();
// @formatter:on
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE))
.isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
assertThat(encodedJws).isNotNull();
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withPublicKey(rsaJwk.toRSAPublicKey()).build();
Jwt jws = jwtDecoder.decode(encodedJws);
assertThat(jws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(SignatureAlgorithm.RS256.getName());
assertThat(jws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(rsaJwk.getKeyID());
assertThat(jws.<String>getClaim(JwtClaimNames.ISS)).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getSubject()).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getAudience())
.isEqualTo(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()));
assertThat(jws.getId()).isNotNull();
assertThat(jws.getIssuedAt()).isNotNull();
assertThat(jws.getExpiresAt()).isNotNull();
}
@Test
public void convertWhenClientSecretJwtClientAuthenticationMethodThenCustomized() {
OctetSequenceKey secretJwk = TestJwks.DEFAULT_SECRET_JWK;
given(this.jwkResolver.apply(any())).willReturn(secretJwk);
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials()
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_JWT)
.build();
// @formatter:on
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
MultiValueMap<String, String> parameters = this.converter.convert(clientCredentialsGrantRequest);
assertThat(parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION_TYPE))
.isEqualTo("urn:ietf:params:oauth:client-assertion-type:jwt-bearer");
String encodedJws = parameters.getFirst(OAuth2ParameterNames.CLIENT_ASSERTION);
assertThat(encodedJws).isNotNull();
NimbusJwtDecoder jwtDecoder = NimbusJwtDecoder.withSecretKey(secretJwk.toSecretKey()).build();
Jwt jws = jwtDecoder.decode(encodedJws);
assertThat(jws.getHeaders().get(JoseHeaderNames.ALG)).isEqualTo(MacAlgorithm.HS256.getName());
assertThat(jws.getHeaders().get(JoseHeaderNames.KID)).isEqualTo(secretJwk.getKeyID());
assertThat(jws.<String>getClaim(JwtClaimNames.ISS)).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getSubject()).isEqualTo(clientRegistration.getClientId());
assertThat(jws.getAudience())
.isEqualTo(Collections.singletonList(clientRegistration.getProviderDetails().getTokenUri()));
assertThat(jws.getId()).isNotNull();
assertThat(jws.getIssuedAt()).isNotNull();
assertThat(jws.getExpiresAt()).isNotNull();
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,28 +16,36 @@
package org.springframework.security.oauth2.client.endpoint;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InOrder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.oauth2.core.endpoint.PkceParameterNames;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationExchanges;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests;
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationResponses;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2AuthorizationCodeGrantRequestEntityConverter}.
@ -46,49 +54,78 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
private OAuth2AuthorizationCodeGrantRequestEntityConverter converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter();
private OAuth2AuthorizationCodeGrantRequestEntityConverter converter;
// @formatter:off
private ClientRegistration.Builder clientRegistrationBuilder = ClientRegistration
.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("https://client.com/callback/client-1")
.scope("read", "write")
.authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/user")
.userNameAttributeName("id")
.clientName("client-1");
// @formatter:on
@Before
public void setup() {
this.converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter();
}
// @formatter:off
private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder = OAuth2AuthorizationRequest
.authorizationCode()
.clientId("client-1")
.state("state-1234")
.authorizationUri("https://provider.com/oauth2/authorize")
.redirectUri("https://client.com/callback/client-1")
.scopes(new HashSet(Arrays.asList("read", "write")));
// @formatter:on
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
// @formatter:off
private OAuth2AuthorizationResponse.Builder authorizationResponseBuilder = OAuth2AuthorizationResponse
.success("code-1234")
.state("state-1234")
.redirectUri("https://client.com/callback/client-1");
// @formatter:on
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() {
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
this.converter.setHeadersConverter(headersConverter1);
Converter<OAuth2AuthorizationCodeGrantRequest, HttpHeaders> headersConverter2 = mock(Converter.class);
this.converter.addHeadersConverter(headersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success();
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest(
clientRegistration, authorizationExchange);
this.converter.convert(authorizationCodeGrantRequest);
InOrder inOrder = inOrder(headersConverter1, headersConverter2);
inOrder.verify(headersConverter1).convert(any(OAuth2AuthorizationCodeGrantRequest.class));
inOrder.verify(headersConverter2).convert(any(OAuth2AuthorizationCodeGrantRequest.class));
}
@Test
public void convertWhenParametersConverterSetThenCalled() {
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter1 = mock(
Converter.class);
this.converter.setParametersConverter(parametersConverter1);
Converter<OAuth2AuthorizationCodeGrantRequest, MultiValueMap<String, String>> parametersConverter2 = mock(
Converter.class);
this.converter.addParametersConverter(parametersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success();
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest(
clientRegistration, authorizationExchange);
this.converter.convert(authorizationCodeGrantRequest);
InOrder inOrder = inOrder(parametersConverter1, parametersConverter2);
inOrder.verify(parametersConverter1).convert(any(OAuth2AuthorizationCodeGrantRequest.class));
inOrder.verify(parametersConverter2).convert(any(OAuth2AuthorizationCodeGrantRequest.class));
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenGrantRequestValidThenConverts() {
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.build();
OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse);
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AuthorizationExchange authorizationExchange = TestOAuth2AuthorizationExchanges.success();
OAuth2AuthorizationRequest authorizationRequest = authorizationExchange.getAuthorizationRequest();
OAuth2AuthorizationResponse authorizationResponse = authorizationExchange.getAuthorizationResponse();
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest(
clientRegistration, authorizationExchange);
RequestEntity<?> requestEntity = this.converter.convert(authorizationCodeGrantRequest);
@ -103,25 +140,25 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE))
.isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234");
assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo(authorizationResponse.getCode());
assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isNull();
assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
.isEqualTo(clientRegistration.getRedirectUri());
.isEqualTo(authorizationRequest.getRedirectUri());
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenPkceGrantRequestValidThenConverts() {
ClientRegistration clientRegistration = this.clientRegistrationBuilder.clientAuthenticationMethod(null)
.clientSecret(null).build();
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.clientAuthenticationMethod(null).clientSecret(null).build();
Map<String, Object> attributes = new HashMap<>();
attributes.put(PkceParameterNames.CODE_VERIFIER, "code-verifier-1234");
Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234");
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.attributes(attributes)
.additionalParameters(additionalParameters).build();
OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBuilder.build();
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.attributes(attributes).additionalParameters(additionalParameters).build();
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().build();
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse);
OAuth2AuthorizationCodeGrantRequest authorizationCodeGrantRequest = new OAuth2AuthorizationCodeGrantRequest(
@ -138,11 +175,13 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE))
.isEqualTo(AuthorizationGrantType.AUTHORIZATION_CODE.getValue());
assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo("code-1234");
assertThat(formParameters.getFirst(OAuth2ParameterNames.CODE)).isEqualTo(authorizationResponse.getCode());
assertThat(formParameters.getFirst(OAuth2ParameterNames.REDIRECT_URI))
.isEqualTo(clientRegistration.getRedirectUri());
assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID)).isEqualTo("client-1");
assertThat(formParameters.getFirst(PkceParameterNames.CODE_VERIFIER)).isEqualTo("code-verifier-1234");
.isEqualTo(authorizationRequest.getRedirectUri());
assertThat(formParameters.getFirst(OAuth2ParameterNames.CLIENT_ID))
.isEqualTo(authorizationRequest.getClientId());
assertThat(formParameters.getFirst(PkceParameterNames.CODE_VERIFIER))
.isEqualTo(authorizationRequest.getAttribute(PkceParameterNames.CODE_VERIFIER));
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,18 +18,24 @@ package org.springframework.security.oauth2.client.endpoint;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InOrder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2ClientCredentialsGrantRequestEntityConverter}.
@ -38,30 +44,76 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OAuth2ClientCredentialsGrantRequestEntityConverterTests {
private OAuth2ClientCredentialsGrantRequestEntityConverter converter = new OAuth2ClientCredentialsGrantRequestEntityConverter();
private OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest;
private OAuth2ClientCredentialsGrantRequestEntityConverter converter;
@Before
public void setup() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.CLIENT_SECRET_BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri("https://provider.com/oauth2/token")
.build();
// @formatter:on
this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
this.converter = new OAuth2ClientCredentialsGrantRequestEntityConverter();
}
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() {
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
this.converter.setHeadersConverter(headersConverter1);
Converter<OAuth2ClientCredentialsGrantRequest, HttpHeaders> headersConverter2 = mock(Converter.class);
this.converter.addHeadersConverter(headersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
this.converter.convert(clientCredentialsGrantRequest);
InOrder inOrder = inOrder(headersConverter1, headersConverter2);
inOrder.verify(headersConverter1).convert(any(OAuth2ClientCredentialsGrantRequest.class));
inOrder.verify(headersConverter2).convert(any(OAuth2ClientCredentialsGrantRequest.class));
}
@Test
public void convertWhenParametersConverterSetThenCalled() {
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter1 = mock(
Converter.class);
this.converter.setParametersConverter(parametersConverter1);
Converter<OAuth2ClientCredentialsGrantRequest, MultiValueMap<String, String>> parametersConverter2 = mock(
Converter.class);
this.converter.addParametersConverter(parametersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
this.converter.convert(clientCredentialsGrantRequest);
InOrder inOrder = inOrder(parametersConverter1, parametersConverter2);
inOrder.verify(parametersConverter1).convert(any(OAuth2ClientCredentialsGrantRequest.class));
inOrder.verify(parametersConverter2).convert(any(OAuth2ClientCredentialsGrantRequest.class));
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenGrantRequestValidThenConverts() {
RequestEntity<?> requestEntity = this.converter.convert(this.clientCredentialsGrantRequest);
ClientRegistration clientRegistration = this.clientCredentialsGrantRequest.getClientRegistration();
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
clientRegistration);
RequestEntity<?> requestEntity = this.converter.convert(clientCredentialsGrantRequest);
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
assertThat(requestEntity.getUrl().toASCIIString())
.isEqualTo(clientRegistration.getProviderDetails().getTokenUri());
@ -73,7 +125,7 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests {
MultiValueMap<String, String> formParameters = (MultiValueMap<String, String>) requestEntity.getBody();
assertThat(formParameters.getFirst(OAuth2ParameterNames.GRANT_TYPE))
.isEqualTo(AuthorizationGrantType.CLIENT_CREDENTIALS.getValue());
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write");
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -18,7 +18,9 @@ package org.springframework.security.oauth2.client.endpoint;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InOrder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -30,6 +32,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2PasswordGrantRequestEntityConverter}.
@ -38,26 +44,76 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OAuth2PasswordGrantRequestEntityConverterTests {
private OAuth2PasswordGrantRequestEntityConverter converter = new OAuth2PasswordGrantRequestEntityConverter();
private OAuth2PasswordGrantRequest passwordGrantRequest;
private OAuth2PasswordGrantRequestEntityConverter converter;
@Before
public void setup() {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.scope("read", "write")
.build();
// @formatter:on
this.passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password");
this.converter = new OAuth2PasswordGrantRequestEntityConverter();
}
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() {
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
this.converter.setHeadersConverter(headersConverter1);
Converter<OAuth2PasswordGrantRequest, HttpHeaders> headersConverter2 = mock(Converter.class);
this.converter.addHeadersConverter(headersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.password().build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1",
"password");
this.converter.convert(passwordGrantRequest);
InOrder inOrder = inOrder(headersConverter1, headersConverter2);
inOrder.verify(headersConverter1).convert(any(OAuth2PasswordGrantRequest.class));
inOrder.verify(headersConverter2).convert(any(OAuth2PasswordGrantRequest.class));
}
@Test
public void convertWhenParametersConverterSetThenCalled() {
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> parametersConverter1 = mock(
Converter.class);
this.converter.setParametersConverter(parametersConverter1);
Converter<OAuth2PasswordGrantRequest, MultiValueMap<String, String>> parametersConverter2 = mock(
Converter.class);
this.converter.addParametersConverter(parametersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.password().build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1",
"password");
this.converter.convert(passwordGrantRequest);
InOrder inOrder = inOrder(parametersConverter1, parametersConverter2);
inOrder.verify(parametersConverter1).convert(any(OAuth2PasswordGrantRequest.class));
inOrder.verify(parametersConverter2).convert(any(OAuth2PasswordGrantRequest.class));
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenGrantRequestValidThenConverts() {
RequestEntity<?> requestEntity = this.converter.convert(this.passwordGrantRequest);
ClientRegistration clientRegistration = this.passwordGrantRequest.getClientRegistration();
ClientRegistration clientRegistration = TestClientRegistrations.password().build();
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1",
"password");
RequestEntity<?> requestEntity = this.converter.convert(passwordGrantRequest);
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
assertThat(requestEntity.getUrl().toASCIIString())
.isEqualTo(clientRegistration.getProviderDetails().getTokenUri());
@ -71,7 +127,7 @@ public class OAuth2PasswordGrantRequestEntityConverterTests {
.isEqualTo(AuthorizationGrantType.PASSWORD.getValue());
assertThat(formParameters.getFirst(OAuth2ParameterNames.USERNAME)).isEqualTo("user1");
assertThat(formParameters.getFirst(OAuth2ParameterNames.PASSWORD)).isEqualTo("password");
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).isEqualTo("read write");
assertThat(formParameters.getFirst(OAuth2ParameterNames.SCOPE)).contains(clientRegistration.getScopes());
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,7 +20,9 @@ import java.util.Collections;
import org.junit.Before;
import org.junit.Test;
import org.mockito.InOrder;
import org.springframework.core.convert.converter.Converter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
@ -28,6 +30,7 @@ import org.springframework.http.RequestEntity;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2AccessToken;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.TestOAuth2AccessTokens;
import org.springframework.security.oauth2.core.TestOAuth2RefreshTokens;
@ -35,6 +38,10 @@ import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.util.MultiValueMap;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link OAuth2RefreshTokenGrantRequestEntityConverter}.
@ -43,23 +50,82 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class OAuth2RefreshTokenGrantRequestEntityConverterTests {
private OAuth2RefreshTokenGrantRequestEntityConverter converter = new OAuth2RefreshTokenGrantRequestEntityConverter();
private OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest;
private OAuth2RefreshTokenGrantRequestEntityConverter converter;
@Before
public void setup() {
this.refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
TestClientRegistrations.clientRegistration().build(), TestOAuth2AccessTokens.scopes("read", "write"),
TestOAuth2RefreshTokens.refreshToken(), Collections.singleton("read"));
this.converter = new OAuth2RefreshTokenGrantRequestEntityConverter();
}
@Test
public void setHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void addHeadersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addHeadersConverter(null))
.withMessage("headersConverter cannot be null");
}
@Test
public void setParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.setParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void addParametersConverterWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.converter.addParametersConverter(null))
.withMessage("parametersConverter cannot be null");
}
@Test
public void convertWhenHeadersConverterSetThenCalled() {
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter1 = mock(Converter.class);
this.converter.setHeadersConverter(headersConverter1);
Converter<OAuth2RefreshTokenGrantRequest, HttpHeaders> headersConverter2 = mock(Converter.class);
this.converter.addHeadersConverter(headersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write");
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
accessToken, refreshToken);
this.converter.convert(refreshTokenGrantRequest);
InOrder inOrder = inOrder(headersConverter1, headersConverter2);
inOrder.verify(headersConverter1).convert(any(OAuth2RefreshTokenGrantRequest.class));
inOrder.verify(headersConverter2).convert(any(OAuth2RefreshTokenGrantRequest.class));
}
@Test
public void convertWhenParametersConverterSetThenCalled() {
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter1 = mock(
Converter.class);
this.converter.setParametersConverter(parametersConverter1);
Converter<OAuth2RefreshTokenGrantRequest, MultiValueMap<String, String>> parametersConverter2 = mock(
Converter.class);
this.converter.addParametersConverter(parametersConverter2);
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write");
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
accessToken, refreshToken);
this.converter.convert(refreshTokenGrantRequest);
InOrder inOrder = inOrder(parametersConverter1, parametersConverter2);
inOrder.verify(parametersConverter1).convert(any(OAuth2RefreshTokenGrantRequest.class));
inOrder.verify(parametersConverter2).convert(any(OAuth2RefreshTokenGrantRequest.class));
}
@SuppressWarnings("unchecked")
@Test
public void convertWhenGrantRequestValidThenConverts() {
RequestEntity<?> requestEntity = this.converter.convert(this.refreshTokenGrantRequest);
ClientRegistration clientRegistration = this.refreshTokenGrantRequest.getClientRegistration();
OAuth2RefreshToken refreshToken = this.refreshTokenGrantRequest.getRefreshToken();
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
OAuth2AccessToken accessToken = TestOAuth2AccessTokens.scopes("read", "write");
OAuth2RefreshToken refreshToken = TestOAuth2RefreshTokens.refreshToken();
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(clientRegistration,
accessToken, refreshToken, Collections.singleton("read"));
RequestEntity<?> requestEntity = this.converter.convert(refreshTokenGrantRequest);
assertThat(requestEntity.getMethod()).isEqualTo(HttpMethod.POST);
assertThat(requestEntity.getUrl().toASCIIString())
.isEqualTo(clientRegistration.getProviderDetails().getTokenUri());

View File

@ -0,0 +1,76 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.springframework.security.oauth2.jose.jws.SignatureAlgorithm;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* @author Joe Grandja
*/
final class TestJoseHeaders {
private TestJoseHeaders() {
}
static JoseHeader.Builder joseHeader() {
return joseHeader(SignatureAlgorithm.RS256);
}
static JoseHeader.Builder joseHeader(SignatureAlgorithm signatureAlgorithm) {
// @formatter:off
return JoseHeader.withAlgorithm(signatureAlgorithm)
.jwkSetUrl("https://provider.com/oauth2/jwks")
.jwk(rsaJwk())
.keyId("keyId")
.x509Url("https://provider.com/oauth2/x509")
.x509CertificateChain(Arrays.asList("x509Cert1", "x509Cert2"))
.x509SHA1Thumbprint("x509SHA1Thumbprint")
.x509SHA256Thumbprint("x509SHA256Thumbprint")
.type("JWT")
.contentType("jwt-content-type")
.header("custom-header-name", "custom-header-value");
// @formatter:on
}
private static Map<String, Object> rsaJwk() {
Map<String, Object> rsaJwk = new HashMap<>();
rsaJwk.put("kty", "RSA");
rsaJwk.put("n", "modulus");
rsaJwk.put("e", "exponent");
return rsaJwk;
}
}

View File

@ -0,0 +1,64 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.endpoint;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
import java.util.Collections;
/*
* NOTE:
* This originated in gh-9208 (JwtEncoder),
* which is required to realize the feature in gh-8175 (JWT Client Authentication).
* However, we decided not to merge gh-9208 as part of the 5.5.0 release
* and instead packaged it up privately with the gh-8175 feature.
* We MAY merge gh-9208 in a later release but that is yet to be determined.
*
* gh-9208 Introduce JwtEncoder
* https://github.com/spring-projects/spring-security/pull/9208
*
* gh-8175 Support JWT for Client Authentication
* https://github.com/spring-projects/spring-security/issues/8175
*/
/**
* @author Joe Grandja
*/
final class TestJwtClaimsSets {
private TestJwtClaimsSets() {
}
static JwtClaimsSet.Builder jwtClaimsSet() {
String issuer = "https://provider.com";
Instant issuedAt = Instant.now();
Instant expiresAt = issuedAt.plus(1, ChronoUnit.HOURS);
// @formatter:off
return JwtClaimsSet.builder()
.issuer(issuer)
.subject("subject")
.audience(Collections.singletonList("client-1"))
.issuedAt(issuedAt)
.notBefore(issuedAt)
.expiresAt(expiresAt)
.id("jti")
.claim("custom-claim-name", "custom-claim-value");
// @formatter:on
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -58,6 +58,17 @@ public final class ClientAuthenticationMethod implements Serializable {
public static final ClientAuthenticationMethod CLIENT_SECRET_POST = new ClientAuthenticationMethod(
"client_secret_post");
/**
* @since 5.5
*/
public static final ClientAuthenticationMethod CLIENT_SECRET_JWT = new ClientAuthenticationMethod(
"client_secret_jwt");
/**
* @since 5.5
*/
public static final ClientAuthenticationMethod PRIVATE_KEY_JWT = new ClientAuthenticationMethod("private_key_jwt");
/**
* @since 5.2
*/

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -48,6 +48,18 @@ public interface OAuth2ParameterNames {
*/
String CLIENT_SECRET = "client_secret";
/**
* {@code client_assertion_type} - used in Access Token Request.
* @since 5.5
*/
String CLIENT_ASSERTION_TYPE = "client_assertion_type";
/**
* {@code client_assertion} - used in Access Token Request.
* @since 5.5
*/
String CLIENT_ASSERTION = "client_assertion";
/**
* {@code redirect_uri} - used in Authorization Request and Access Token Request.
*/

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2020 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -53,6 +53,16 @@ public class ClientAuthenticationMethodTests {
assertThat(ClientAuthenticationMethod.CLIENT_SECRET_POST.getValue()).isEqualTo("client_secret_post");
}
@Test
public void getValueWhenAuthenticationMethodClientSecretJwtThenReturnClientSecretJwt() {
assertThat(ClientAuthenticationMethod.CLIENT_SECRET_JWT.getValue()).isEqualTo("client_secret_jwt");
}
@Test
public void getValueWhenAuthenticationMethodPrivateKeyJwtThenReturnPrivateKeyJwt() {
assertThat(ClientAuthenticationMethod.PRIVATE_KEY_JWT.getValue()).isEqualTo("private_key_jwt");
}
@Test
public void getValueWhenAuthenticationMethodNoneThenReturnNone() {
assertThat(ClientAuthenticationMethod.NONE.getValue()).isEqualTo("none");

View File

@ -0,0 +1,86 @@
/*
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.jose;
import java.security.interfaces.ECPrivateKey;
import java.security.interfaces.ECPublicKey;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import javax.crypto.SecretKey;
import com.nimbusds.jose.jwk.Curve;
import com.nimbusds.jose.jwk.ECKey;
import com.nimbusds.jose.jwk.OctetSequenceKey;
import com.nimbusds.jose.jwk.RSAKey;
/**
* @author Joe Grandja
*/
public final class TestJwks {
// @formatter:off
public static final RSAKey DEFAULT_RSA_JWK =
jwk(
TestKeys.DEFAULT_PUBLIC_KEY,
TestKeys.DEFAULT_PRIVATE_KEY
).build();
// @formatter:on
// @formatter:off
public static final ECKey DEFAULT_EC_JWK =
jwk(
(ECPublicKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPublic(),
(ECPrivateKey) TestKeys.DEFAULT_EC_KEY_PAIR.getPrivate()
).build();
// @formatter:on
// @formatter:off
public static final OctetSequenceKey DEFAULT_SECRET_JWK =
jwk(
TestKeys.DEFAULT_SECRET_KEY
).build();
// @formatter:on
private TestJwks() {
}
public static RSAKey.Builder jwk(RSAPublicKey publicKey, RSAPrivateKey privateKey) {
// @formatter:off
return new RSAKey.Builder(publicKey)
.privateKey(privateKey)
.keyID("rsa-jwk-kid");
// @formatter:on
}
public static ECKey.Builder jwk(ECPublicKey publicKey, ECPrivateKey privateKey) {
// @formatter:off
Curve curve = Curve.forECParameterSpec(publicKey.getParams());
return new ECKey.Builder(curve, publicKey)
.privateKey(privateKey)
.keyID("ec-jwk-kid");
// @formatter:on
}
public static OctetSequenceKey.Builder jwk(SecretKey secretKey) {
// @formatter:off
return new OctetSequenceKey.Builder(secretKey)
.keyID("secret-jwk-kid");
// @formatter:on
}
}

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2019 the original author or authors.
* Copyright 2002-2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,10 +16,17 @@
package org.springframework.security.oauth2.jose;
import java.math.BigInteger;
import java.security.KeyFactory;
import java.security.KeyPair;
import java.security.KeyPairGenerator;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPrivateKey;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.ECFieldFp;
import java.security.spec.ECParameterSpec;
import java.security.spec.ECPoint;
import java.security.spec.EllipticCurve;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.PKCS8EncodedKeySpec;
import java.security.spec.X509EncodedKeySpec;
@ -109,6 +116,34 @@ public final class TestKeys {
}
}
public static final KeyPair DEFAULT_RSA_KEY_PAIR = new KeyPair(DEFAULT_PUBLIC_KEY, DEFAULT_PRIVATE_KEY);
public static final KeyPair DEFAULT_EC_KEY_PAIR = generateEcKeyPair();
static KeyPair generateEcKeyPair() {
EllipticCurve ellipticCurve = new EllipticCurve(
new ECFieldFp(new BigInteger(
"115792089210356248762697446949407573530086143415290314195533631308867097853951")),
new BigInteger("115792089210356248762697446949407573530086143415290314195533631308867097853948"),
new BigInteger("41058363725152142129326129780047268409114441015993725554835256314039467401291"));
ECPoint ecPoint = new ECPoint(
new BigInteger("48439561293906451759052585252797914202762949526041747995844080717082404635286"),
new BigInteger("36134250956749795798585127919587881956611106672985015071877198253568414405109"));
ECParameterSpec ecParameterSpec = new ECParameterSpec(ellipticCurve, ecPoint,
new BigInteger("115792089210356248762697446949407573529996955224135760342422259061068512044369"), 1);
KeyPair keyPair;
try {
KeyPairGenerator keyPairGenerator = KeyPairGenerator.getInstance("EC");
keyPairGenerator.initialize(ecParameterSpec);
keyPair = keyPairGenerator.generateKeyPair();
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
return keyPair;
}
private TestKeys() {
}