Polish oauth-client format

Issue gh-8945
This commit is contained in:
Rob Winch 2020-08-24 09:49:04 -05:00
parent 38aae7f015
commit dc47a7575e
97 changed files with 3421 additions and 1145 deletions

View File

@ -163,14 +163,16 @@ public final class AuthorizedClientServiceOAuth2AuthorizedClientManager implemen
private OAuth2AuthorizationContext buildAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest, private OAuth2AuthorizationContext buildAuthorizationContext(OAuth2AuthorizeRequest authorizeRequest,
Authentication principal, OAuth2AuthorizationContext.Builder contextBuilder) { Authentication principal, OAuth2AuthorizationContext.Builder contextBuilder) {
OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal) // @formatter:off
return contextBuilder.principal(principal)
.attributes((attributes) -> { .attributes((attributes) -> {
Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest);
if (!CollectionUtils.isEmpty(contextAttributes)) { if (!CollectionUtils.isEmpty(contextAttributes)) {
attributes.putAll(contextAttributes); attributes.putAll(contextAttributes);
} }
}).build(); })
return authorizationContext; .build();
// @formatter:on
} }
/** /**

View File

@ -81,9 +81,12 @@ public final class InMemoryReactiveOAuth2AuthorizedClientService implements Reac
public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) { public Mono<Void> removeAuthorizedClient(String clientRegistrationId, String principalName) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.hasText(principalName, "principalName cannot be empty"); Assert.hasText(principalName, "principalName cannot be empty");
// @formatter:off
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName)) .map((clientRegistration) -> new OAuth2AuthorizedClientId(clientRegistrationId, principalName))
.doOnNext(this.authorizedClients::remove).then(Mono.empty()); .doOnNext(this.authorizedClients::remove)
.then(Mono.empty());
// @formatter:on
} }
} }

View File

@ -64,26 +64,42 @@ import org.springframework.util.StringUtils;
*/ */
public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService { public class JdbcOAuth2AuthorizedClientService implements OAuth2AuthorizedClientService {
private static final String COLUMN_NAMES = "client_registration_id, " + "principal_name, " + "access_token_type, " // @formatter:off
+ "access_token_value, " + "access_token_issued_at, " + "access_token_expires_at, " private static final String COLUMN_NAMES = "client_registration_id, "
+ "access_token_scopes, " + "refresh_token_value, " + "refresh_token_issued_at"; + "principal_name, "
+ "access_token_type, "
+ "access_token_value, "
+ "access_token_issued_at, "
+ "access_token_expires_at, "
+ "access_token_scopes, "
+ "refresh_token_value, "
+ "refresh_token_issued_at";
// @formatter:on
private static final String TABLE_NAME = "oauth2_authorized_client"; private static final String TABLE_NAME = "oauth2_authorized_client";
private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?"; private static final String PK_FILTER = "client_registration_id = ? AND principal_name = ?";
private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES + " FROM " + TABLE_NAME // @formatter:off
private static final String LOAD_AUTHORIZED_CLIENT_SQL = "SELECT " + COLUMN_NAMES
+ " FROM " + TABLE_NAME
+ " WHERE " + PK_FILTER; + " WHERE " + PK_FILTER;
// @formatter:on
private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME + " (" + COLUMN_NAMES // @formatter:off
+ ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; private static final String SAVE_AUTHORIZED_CLIENT_SQL = "INSERT INTO " + TABLE_NAME
+ " (" + COLUMN_NAMES + ") VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)";
// @formatter:on
private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER; private static final String REMOVE_AUTHORIZED_CLIENT_SQL = "DELETE FROM " + TABLE_NAME + " WHERE " + PK_FILTER;
// @formatter:off
private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME private static final String UPDATE_AUTHORIZED_CLIENT_SQL = "UPDATE " + TABLE_NAME
+ " SET access_token_type = ?, access_token_value = ?, access_token_issued_at = ?," + " SET access_token_type = ?, access_token_value = ?, access_token_issued_at = ?,"
+ " access_token_expires_at = ?, access_token_scopes = ?," + " access_token_expires_at = ?, access_token_scopes = ?,"
+ " refresh_token_value = ?, refresh_token_issued_at = ?" + " WHERE " + PK_FILTER; + " refresh_token_value = ?, refresh_token_issued_at = ?"
+ " WHERE " + PK_FILTER;
// @formatter:on
protected final JdbcOperations jdbcOperations; protected final JdbcOperations jdbcOperations;

View File

@ -68,11 +68,15 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
@Override @Override
public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) { public Mono<OAuth2AccessTokenResponse> getTokenResponse(T grantRequest) {
Assert.notNull(grantRequest, "grantRequest cannot be null"); Assert.notNull(grantRequest, "grantRequest cannot be null");
return Mono.defer( // @formatter:off
() -> this.webClient.post().uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri()) return Mono.defer(() -> this.webClient.post()
.headers((headers) -> populateTokenRequestHeaders(grantRequest, headers)) .uri(clientRegistration(grantRequest).getProviderDetails().getTokenUri())
.body(createTokenRequestBody(grantRequest)).exchange() .headers((headers) -> populateTokenRequestHeaders(grantRequest, headers))
.flatMap((response) -> readTokenResponse(grantRequest, response))); .body(createTokenRequestBody(grantRequest))
.exchange()
.flatMap((response) -> readTokenResponse(grantRequest, response))
);
// @formatter:on
} }
/** /**
@ -187,7 +191,12 @@ public abstract class AbstractWebClientReactiveOAuth2AccessTokenResponseClient<T
OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) { OAuth2AccessTokenResponse populateTokenResponse(T grantRequest, OAuth2AccessTokenResponse tokenResponse) {
if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) { if (CollectionUtils.isEmpty(tokenResponse.getAccessToken().getScopes())) {
Set<String> defaultScopes = defaultScopes(grantRequest); Set<String> defaultScopes = defaultScopes(grantRequest);
tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse).scopes(defaultScopes).build(); // @formatter:off
tokenResponse = OAuth2AccessTokenResponse
.withResponse(tokenResponse)
.scopes(defaultScopes)
.build();
// @formatter:on
} }
return tokenResponse; return tokenResponse;
} }

View File

@ -82,8 +82,11 @@ public final class DefaultAuthorizationCodeTokenResponseClient
// https://tools.ietf.org/html/rfc6749#section-5.1 // https://tools.ietf.org/html/rfc6749#section-5.1
// If AccessTokenResponse.scope is empty, then default to the scope // If AccessTokenResponse.scope is empty, then default to the scope
// originally requested by the client in the Token Request // originally requested by the client in the Token Request
// @formatter:off
tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
.scopes(authorizationCodeGrantRequest.getClientRegistration().getScopes()).build(); .scopes(authorizationCodeGrantRequest.getClientRegistration().getScopes())
.build();
// @formatter:on
} }
return tokenResponse; return tokenResponse;
} }

View File

@ -82,8 +82,11 @@ public final class DefaultClientCredentialsTokenResponseClient
// https://tools.ietf.org/html/rfc6749#section-5.1 // https://tools.ietf.org/html/rfc6749#section-5.1
// If AccessTokenResponse.scope is empty, then default to the scope // If AccessTokenResponse.scope is empty, then default to the scope
// originally requested by the client in the Token Request // originally requested by the client in the Token Request
// @formatter:off
tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse) tokenResponse = OAuth2AccessTokenResponse.withResponse(tokenResponse)
.scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes()).build(); .scopes(clientCredentialsGrantRequest.getClientRegistration().getScopes())
.build();
// @formatter:on
} }
return tokenResponse; return tokenResponse;
} }

View File

@ -119,8 +119,15 @@ public class NimbusAuthorizationCodeTokenResponseClient
refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue(); refreshToken = accessTokenResponse.getTokens().getRefreshToken().getValue();
} }
Map<String, Object> additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters()); Map<String, Object> additionalParameters = new LinkedHashMap<>(accessTokenResponse.getCustomParameters());
return OAuth2AccessTokenResponse.withToken(accessToken).tokenType(accessTokenType).expiresIn(expiresIn) // @formatter:off
.scopes(scopes).refreshToken(refreshToken).additionalParameters(additionalParameters).build(); return OAuth2AccessTokenResponse.withToken(accessToken)
.tokenType(accessTokenType)
.expiresIn(expiresIn)
.scopes(scopes)
.refreshToken(refreshToken)
.additionalParameters(additionalParameters)
.build();
// @formatter:on
} }
private com.nimbusds.oauth2.sdk.TokenResponse getTokenResponse(AuthorizationGrant authorizationCodeGrant, private com.nimbusds.oauth2.sdk.TokenResponse getTokenResponse(AuthorizationGrant authorizationCodeGrant,

View File

@ -191,24 +191,31 @@ public class OidcAuthorizationCodeReactiveAuthenticationManager implements React
null); null);
return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString())); return Mono.error(new OAuth2AuthenticationException(invalidIdTokenError, invalidIdTokenError.toString()));
} }
// @formatter:off
return createOidcToken(clientRegistration, accessTokenResponse) return createOidcToken(clientRegistration, accessTokenResponse)
.doOnNext((idToken) -> validateNonce(authorizationCodeAuthentication, idToken)) .doOnNext((idToken) -> validateNonce(authorizationCodeAuthentication, idToken))
.map((idToken) -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters)) .map((idToken) -> new OidcUserRequest(clientRegistration, accessToken, idToken, additionalParameters))
.flatMap(this.userService::loadUser).map((oauth2User) -> { .flatMap(this.userService::loadUser)
.map((oauth2User) -> {
Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper Collection<? extends GrantedAuthority> mappedAuthorities = this.authoritiesMapper
.mapAuthorities(oauth2User.getAuthorities()); .mapAuthorities(oauth2User.getAuthorities());
return new OAuth2LoginAuthenticationToken(authorizationCodeAuthentication.getClientRegistration(), return new OAuth2LoginAuthenticationToken(authorizationCodeAuthentication.getClientRegistration(),
authorizationCodeAuthentication.getAuthorizationExchange(), oauth2User, mappedAuthorities, authorizationCodeAuthentication.getAuthorizationExchange(), oauth2User, mappedAuthorities,
accessToken, accessTokenResponse.getRefreshToken()); accessToken, accessTokenResponse.getRefreshToken());
}); });
// @formatter:on
} }
private Mono<OidcIdToken> createOidcToken(ClientRegistration clientRegistration, private Mono<OidcIdToken> createOidcToken(ClientRegistration clientRegistration,
OAuth2AccessTokenResponse accessTokenResponse) { OAuth2AccessTokenResponse accessTokenResponse) {
ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration); ReactiveJwtDecoder jwtDecoder = this.jwtDecoderFactory.createDecoder(clientRegistration);
String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN); String rawIdToken = (String) accessTokenResponse.getAdditionalParameters().get(OidcParameterNames.ID_TOKEN);
return jwtDecoder.decode(rawIdToken).map( // @formatter:off
(jwt) -> new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims())); return jwtDecoder.decode(rawIdToken)
.map((jwt) ->
new OidcIdToken(jwt.getTokenValue(), jwt.getIssuedAt(), jwt.getExpiresAt(), jwt.getClaims())
);
// @formatter:on
} }
private static Mono<OidcIdToken> validateNonce( private static Mono<OidcIdToken> validateNonce(

View File

@ -97,8 +97,13 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
@Override @Override
public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException { public Mono<OidcUser> loadUser(OidcUserRequest userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null"); Assert.notNull(userRequest, "userRequest cannot be null");
return getUserInfo(userRequest).map((userInfo) -> new OidcUserAuthority(userRequest.getIdToken(), userInfo)) // @formatter:off
.defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null)).map((authority) -> { return getUserInfo(userRequest)
.map((userInfo) ->
new OidcUserAuthority(userRequest.getIdToken(), userInfo)
)
.defaultIfEmpty(new OidcUserAuthority(userRequest.getIdToken(), null))
.map((authority) -> {
OidcUserInfo userInfo = authority.getUserInfo(); OidcUserInfo userInfo = authority.getUserInfo();
Set<GrantedAuthority> authorities = new HashSet<>(); Set<GrantedAuthority> authorities = new HashSet<>();
authorities.add(authority); authorities.add(authority);
@ -114,14 +119,19 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
} }
return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo); return new DefaultOidcUser(authorities, userRequest.getIdToken(), userInfo);
}); });
// @formatter:on
} }
private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) { private Mono<OidcUserInfo> getUserInfo(OidcUserRequest userRequest) {
if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) { if (!OidcUserRequestUtils.shouldRetrieveUserInfo(userRequest)) {
return Mono.empty(); return Mono.empty();
} }
return this.oauth2UserService.loadUser(userRequest).map(OAuth2User::getAttributes) // @formatter:off
.map((claims) -> convertClaims(claims, userRequest.getClientRegistration())).map(OidcUserInfo::new) return this.oauth2UserService
.loadUser(userRequest)
.map(OAuth2User::getAttributes)
.map((claims) -> convertClaims(claims, userRequest.getClientRegistration()))
.map(OidcUserInfo::new)
.doOnNext((userInfo) -> { .doOnNext((userInfo) -> {
String subject = userInfo.getSubject(); String subject = userInfo.getSubject();
if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) { if (subject == null || !subject.equals(userRequest.getIdToken().getSubject())) {
@ -129,6 +139,7 @@ public class OidcReactiveOAuth2UserService implements ReactiveOAuth2UserService<
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
}); });
// @formatter:on
} }
private Map<String, Object> convertClaims(Map<String, Object> claims, ClientRegistration clientRegistration) { private Map<String, Object> convertClaims(Map<String, Object> claims, ClientRegistration clientRegistration) {

View File

@ -93,10 +93,17 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo
if (this.postLogoutRedirectUri == null) { if (this.postLogoutRedirectUri == null) {
return null; return null;
} }
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) // @formatter:off
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); UriComponents uriComponents = UriComponentsBuilder
.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath())
.replaceQuery(null)
.fragment(null)
.build();
return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri) return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri)
.buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString())).toUri(); .buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString()))
.toUri();
// @formatter:on
} }
private String endpointUri(URI endSessionEndpoint, String idToken, URI postLogoutRedirectUri) { private String endpointUri(URI endSessionEndpoint, String idToken, URI postLogoutRedirectUri) {
@ -105,7 +112,11 @@ public final class OidcClientInitiatedLogoutSuccessHandler extends SimpleUrlLogo
if (postLogoutRedirectUri != null) { if (postLogoutRedirectUri != null) {
builder.queryParam("post_logout_redirect_uri", postLogoutRedirectUri); builder.queryParam("post_logout_redirect_uri", postLogoutRedirectUri);
} }
return builder.encode(StandardCharsets.UTF_8).build().toUriString(); // @formatter:off
return builder.encode(StandardCharsets.UTF_8)
.build()
.toUriString();
// @formatter:on
} }
/** /**

View File

@ -72,11 +72,14 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo
@Override @Override
public Mono<Void> onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) { public Mono<Void> onLogoutSuccess(WebFilterExchange exchange, Authentication authentication) {
return Mono.just(authentication).filter(OAuth2AuthenticationToken.class::isInstance) // @formatter:off
return Mono.just(authentication)
.filter(OAuth2AuthenticationToken.class::isInstance)
.filter((token) -> authentication.getPrincipal() instanceof OidcUser) .filter((token) -> authentication.getPrincipal() instanceof OidcUser)
.map(OAuth2AuthenticationToken.class::cast) .map(OAuth2AuthenticationToken.class::cast)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId) .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId)
.flatMap(this.clientRegistrationRepository::findByRegistrationId).flatMap((clientRegistration) -> { .flatMap(this.clientRegistrationRepository::findByRegistrationId)
.flatMap((clientRegistration) -> {
URI endSessionEndpoint = endSessionEndpoint(clientRegistration); URI endSessionEndpoint = endSessionEndpoint(clientRegistration);
if (endSessionEndpoint == null) { if (endSessionEndpoint == null) {
return Mono.empty(); return Mono.empty();
@ -86,8 +89,10 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo
return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri)); return Mono.just(endpointUri(endSessionEndpoint, idToken, postLogoutRedirectUri));
}) })
.switchIfEmpty( .switchIfEmpty(
this.serverLogoutSuccessHandler.onLogoutSuccess(exchange, authentication).then(Mono.empty())) this.serverLogoutSuccessHandler.onLogoutSuccess(exchange, authentication).then(Mono.empty())
)
.flatMap((endpointUri) -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri)); .flatMap((endpointUri) -> this.redirectStrategy.sendRedirect(exchange.getExchange(), endpointUri));
// @formatter:on
} }
private URI endSessionEndpoint(ClientRegistration clientRegistration) { private URI endSessionEndpoint(ClientRegistration clientRegistration) {
@ -118,10 +123,16 @@ public class OidcClientInitiatedServerLogoutSuccessHandler implements ServerLogo
if (this.postLogoutRedirectUri == null) { if (this.postLogoutRedirectUri == null) {
return null; return null;
} }
// @formatter:off
UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI())
.replacePath(request.getPath().contextPath().value()).replaceQuery(null).fragment(null).build(); .replacePath(request.getPath().contextPath().value())
.replaceQuery(null)
.fragment(null)
.build();
return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri) return UriComponentsBuilder.fromUriString(this.postLogoutRedirectUri)
.buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString())).toUri(); .buildAndExpand(Collections.singletonMap("baseUrl", uriComponents.toUriString()))
.toUri();
// @formatter:on
} }
/** /**

View File

@ -168,11 +168,19 @@ public final class ClientRegistration implements Serializable {
@Override @Override
public String toString() { public String toString() {
return "ClientRegistration{" + "registrationId='" + this.registrationId + '\'' + ", clientId='" + this.clientId // @formatter:off
+ '\'' + ", clientSecret='" + this.clientSecret + '\'' + ", clientAuthenticationMethod=" return "ClientRegistration{"
+ this.clientAuthenticationMethod + ", authorizationGrantType=" + this.authorizationGrantType + "registrationId='" + this.registrationId + '\''
+ ", redirectUri='" + this.redirectUri + '\'' + ", scopes=" + this.scopes + ", providerDetails=" + ", clientId='" + this.clientId + '\''
+ this.providerDetails + ", clientName='" + this.clientName + '\'' + '}'; + ", clientSecret='" + this.clientSecret + '\''
+ ", clientAuthenticationMethod=" + this.clientAuthenticationMethod
+ ", authorizationGrantType=" + this.authorizationGrantType
+ ", redirectUri='" + this.redirectUri
+ '\'' + ", scopes=" + this.scopes
+ ", providerDetails=" + this.providerDetails
+ ", clientName='" + this.clientName + '\''
+ '}';
// @formatter:on
} }
/** /**

View File

@ -148,8 +148,11 @@ public final class ClientRegistrations {
} }
private static Supplier<ClientRegistration.Builder> oidc(URI issuer) { private static Supplier<ClientRegistration.Builder> oidc(URI issuer) {
URI uri = UriComponentsBuilder.fromUri(issuer).replacePath(issuer.getPath() + OIDC_METADATA_PATH) // @formatter:off
URI uri = UriComponentsBuilder.fromUri(issuer)
.replacePath(issuer.getPath() + OIDC_METADATA_PATH)
.build(Collections.emptyMap()); .build(Collections.emptyMap());
// @formatter:on
return () -> { return () -> {
RequestEntity<Void> request = RequestEntity.get(uri).build(); RequestEntity<Void> request = RequestEntity.get(uri).build();
Map<String, Object> configuration = rest.exchange(request, typeReference).getBody(); Map<String, Object> configuration = rest.exchange(request, typeReference).getBody();
@ -164,14 +167,20 @@ public final class ClientRegistrations {
} }
private static Supplier<ClientRegistration.Builder> oidcRfc8414(URI issuer) { private static Supplier<ClientRegistration.Builder> oidcRfc8414(URI issuer) {
URI uri = UriComponentsBuilder.fromUri(issuer).replacePath(OIDC_METADATA_PATH + issuer.getPath()) // @formatter:off
URI uri = UriComponentsBuilder.fromUri(issuer)
.replacePath(OIDC_METADATA_PATH + issuer.getPath())
.build(Collections.emptyMap()); .build(Collections.emptyMap());
// @formatter:on
return getRfc8414Builder(issuer, uri); return getRfc8414Builder(issuer, uri);
} }
private static Supplier<ClientRegistration.Builder> oauth(URI issuer) { private static Supplier<ClientRegistration.Builder> oauth(URI issuer) {
URI uri = UriComponentsBuilder.fromUri(issuer).replacePath(OAUTH_METADATA_PATH + issuer.getPath()) // @formatter:off
URI uri = UriComponentsBuilder.fromUri(issuer)
.replacePath(OAUTH_METADATA_PATH + issuer.getPath())
.build(Collections.emptyMap()); .build(Collections.emptyMap());
// @formatter:on
return getRfc8414Builder(issuer, uri); return getRfc8414Builder(issuer, uri);
} }
@ -242,12 +251,19 @@ public final class ClientRegistrations {
+ "\" returned a configuration of " + grantTypes); + "\" returned a configuration of " + grantTypes);
List<String> scopes = getScopes(metadata); List<String> scopes = getScopes(metadata);
Map<String, Object> configurationMetadata = new LinkedHashMap<>(metadata.toJSONObject()); Map<String, Object> configurationMetadata = new LinkedHashMap<>(metadata.toJSONObject());
return ClientRegistration.withRegistrationId(name).userNameAttributeName(IdTokenClaimNames.SUB).scope(scopes) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).clientAuthenticationMethod(method) return ClientRegistration.withRegistrationId(name)
.userNameAttributeName(IdTokenClaimNames.SUB)
.scope(scopes)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientAuthenticationMethod(method)
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.authorizationUri(metadata.getAuthorizationEndpointURI().toASCIIString()) .authorizationUri(metadata.getAuthorizationEndpointURI().toASCIIString())
.providerConfigurationMetadata(configurationMetadata) .providerConfigurationMetadata(configurationMetadata)
.tokenUri(metadata.getTokenEndpointURI().toASCIIString()).issuerUri(issuer).clientName(issuer); .tokenUri(metadata.getTokenEndpointURI().toASCIIString())
.issuerUri(issuer)
.clientName(issuer);
// @formatter:on
} }
private static ClientAuthenticationMethod getClientAuthenticationMethod(String issuer, private static ClientAuthenticationMethod getClientAuthenticationMethod(String issuer,

View File

@ -108,13 +108,18 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi
.getUserInfoEndpoint().getAuthenticationMethod(); .getUserInfoEndpoint().getAuthenticationMethod();
WebClient.RequestHeadersSpec<?> requestHeadersSpec = getRequestHeaderSpec(userRequest, userInfoUri, WebClient.RequestHeadersSpec<?> requestHeadersSpec = getRequestHeaderSpec(userRequest, userInfoUri,
authenticationMethod); authenticationMethod);
// @formatter:off
Mono<Map<String, Object>> userAttributes = requestHeadersSpec.retrieve() Mono<Map<String, Object>> userAttributes = requestHeadersSpec.retrieve()
.onStatus((s) -> s != HttpStatus.OK, (response) -> parse(response).map((userInfoErrorResponse) -> { .onStatus((s) -> s != HttpStatus.OK, (response) ->
String description = userInfoErrorResponse.getErrorObject().getDescription(); parse(response)
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, description, .map((userInfoErrorResponse) -> {
null); String description = userInfoErrorResponse.getErrorObject().getDescription();
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, description,
})).bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP); null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
})
)
.bodyToMono(DefaultReactiveOAuth2UserService.STRING_OBJECT_MAP);
return userAttributes.map((attrs) -> { return userAttributes.map((attrs) -> {
GrantedAuthority authority = new OAuth2UserAuthority(attrs); GrantedAuthority authority = new OAuth2UserAuthority(attrs);
Set<GrantedAuthority> authorities = new HashSet<>(); Set<GrantedAuthority> authorities = new HashSet<>();
@ -125,40 +130,54 @@ public class DefaultReactiveOAuth2UserService implements ReactiveOAuth2UserServi
} }
return new DefaultOAuth2User(authorities, attrs, userNameAttributeName); return new DefaultOAuth2User(authorities, attrs, userNameAttributeName);
}).onErrorMap(IOException.class, })
.onErrorMap(IOException.class,
(ex) -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri, (ex) -> new AuthenticationServiceException("Unable to access the userInfoEndpoint " + userInfoUri,
ex)) ex)
.onErrorMap(UnsupportedMediaTypeException.class, (ex) -> { )
String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '" .onErrorMap(UnsupportedMediaTypeException.class, (ex) -> {
+ userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint() String errorMessage = "An error occurred while attempting to retrieve the UserInfo Resource from '"
.getUri() + userRequest.getClientRegistration().getProviderDetails().getUserInfoEndpoint()
+ "': response contains invalid content type '" + ex.getContentType().toString() + "'. " .getUri()
+ "The UserInfo Response should return a JSON object (content type 'application/json') " + "': response contains invalid content type '" + ex.getContentType().toString() + "'. "
+ "that contains a collection of name and value pairs of the claims about the authenticated End-User. " + "The UserInfo Response should return a JSON object (content type 'application/json') "
+ "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '" + "that contains a collection of name and value pairs of the claims about the authenticated End-User. "
+ userRequest.getClientRegistration().getRegistrationId() + "Please ensure the UserInfo Uri in UserInfoEndpoint for Client Registration '"
+ "' conforms to the UserInfo Endpoint, " + userRequest.getClientRegistration().getRegistrationId()
+ "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'"; + "' conforms to the UserInfo Endpoint, "
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage, + "as defined in OpenID Connect 1.0: 'https://openid.net/specs/openid-connect-core-1_0.html#UserInfo'";
null); OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, errorMessage,
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex); null);
}).onErrorMap((t) -> !(t instanceof AuthenticationServiceException), (t) -> { throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), ex);
OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE, })
"An error occurred reading the UserInfo Success response: " + t.getMessage(), null); .onErrorMap((t) -> !(t instanceof AuthenticationServiceException), (t) -> {
return new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), t); OAuth2Error oauth2Error = new OAuth2Error(INVALID_USER_INFO_RESPONSE_ERROR_CODE,
}); "An error occurred reading the UserInfo Success response: " + t.getMessage(), null);
return new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString(), t);
});
}); });
// @formatter:on
} }
private WebClient.RequestHeadersSpec<?> getRequestHeaderSpec(OAuth2UserRequest userRequest, String userInfoUri, private WebClient.RequestHeadersSpec<?> getRequestHeaderSpec(OAuth2UserRequest userRequest, String userInfoUri,
AuthenticationMethod authenticationMethod) { AuthenticationMethod authenticationMethod) {
if (AuthenticationMethod.FORM.equals(authenticationMethod)) { if (AuthenticationMethod.FORM.equals(authenticationMethod)) {
return this.webClient.post().uri(userInfoUri).header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) // @formatter:off
return this.webClient.post()
.uri(userInfoUri)
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE) .header(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_FORM_URLENCODED_VALUE)
.syncBody("access_token=" + userRequest.getAccessToken().getTokenValue()); .bodyValue("access_token=" + userRequest.getAccessToken().getTokenValue());
// @formatter:on
} }
return this.webClient.get().uri(userInfoUri).header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) // @formatter:off
.headers((headers) -> headers.setBearerAuth(userRequest.getAccessToken().getTokenValue())); return this.webClient.get()
.uri(userInfoUri)
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.headers((headers) -> headers
.setBearerAuth(userRequest.getAccessToken().getTokenValue())
);
// @formatter:on
} }
/** /**

View File

@ -58,8 +58,13 @@ public class DelegatingOAuth2UserService<R extends OAuth2UserRequest, U extends
@Override @Override
public U loadUser(R userRequest) throws OAuth2AuthenticationException { public U loadUser(R userRequest) throws OAuth2AuthenticationException {
Assert.notNull(userRequest, "userRequest cannot be null"); Assert.notNull(userRequest, "userRequest cannot be null");
return this.userServices.stream().map((userService) -> userService.loadUser(userRequest)) // @formatter:off
.filter(Objects::nonNull).findFirst().orElse(null); return this.userServices.stream()
.map((userService) -> userService.loadUser(userRequest))
.filter(Objects::nonNull)
.findFirst()
.orElse(null);
// @formatter:on
} }
} }

View File

@ -153,10 +153,14 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction); String redirectUriStr = expandRedirectUri(request, clientRegistration, redirectUriAction);
// @formatter:off
builder.clientId(clientRegistration.getClientId()) builder.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) .redirectUri(redirectUriStr)
.state(this.stateGenerator.generateKey()).attributes(attributes); .scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes);
// @formatter:on
this.authorizationRequestCustomizer.accept(builder); this.authorizationRequestCustomizer.accept(builder);
@ -219,8 +223,13 @@ public final class DefaultOAuth2AuthorizationRequestResolver implements OAuth2Au
String action) { String action) {
Map<String, String> uriVariables = new HashMap<>(); Map<String, String> uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId()); uriVariables.put("registrationId", clientRegistration.getRegistrationId());
// @formatter:off
UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)) UriComponents uriComponents = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replacePath(request.getContextPath()).replaceQuery(null).fragment(null).build(); .replacePath(request.getContextPath())
.replaceQuery(null)
.fragment(null)
.build();
// @formatter:on
String scheme = uriComponents.getScheme(); String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
String host = uriComponents.getHost(); String host = uriComponents.getHost();

View File

@ -87,8 +87,14 @@ import org.springframework.web.context.request.ServletRequestAttributes;
*/ */
public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager { public final class DefaultOAuth2AuthorizedClientManager implements OAuth2AuthorizedClientManager {
private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = OAuth2AuthorizedClientProviderBuilder // @formatter:off
.builder().authorizationCode().refreshToken().clientCredentials().password().build(); private static final OAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.password()
.build();
// @formatter:on
private final ClientRegistrationRepository clientRegistrationRepository; private final ClientRegistrationRepository clientRegistrationRepository;
@ -156,13 +162,16 @@ public final class DefaultOAuth2AuthorizedClientManager implements OAuth2Authori
contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration); contextBuilder = OAuth2AuthorizationContext.withClientRegistration(clientRegistration);
} }
} }
// @formatter:off
OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal) OAuth2AuthorizationContext authorizationContext = contextBuilder.principal(principal)
.attributes((attributes) -> { .attributes((attributes) -> {
Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest); Map<String, Object> contextAttributes = this.contextAttributesMapper.apply(authorizeRequest);
if (!CollectionUtils.isEmpty(contextAttributes)) { if (!CollectionUtils.isEmpty(contextAttributes)) {
attributes.putAll(contextAttributes); attributes.putAll(contextAttributes);
} }
}).build(); })
.build();
// @formatter:on
try { try {
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
} }

View File

@ -93,11 +93,20 @@ import org.springframework.web.server.ServerWebExchange;
*/ */
public final class DefaultReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager { public final class DefaultReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = ReactiveOAuth2AuthorizedClientProviderBuilder // @formatter:off
.builder().authorizationCode().refreshToken().clientCredentials().password().build(); private static final ReactiveOAuth2AuthorizedClientProvider DEFAULT_AUTHORIZED_CLIENT_PROVIDER = ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.password()
.build();
// @formatter:on
// @formatter:off
private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext() private static final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext()
.filter((c) -> c.hasKey(ServerWebExchange.class)).map((c) -> c.get(ServerWebExchange.class)); .filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class));
// @formatter:on
private final ReactiveClientRegistrationRepository clientRegistrationRepository; private final ReactiveClientRegistrationRepository clientRegistrationRepository;
@ -138,28 +147,32 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId(); String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal(); Authentication principal = authorizeRequest.getPrincipal();
// @formatter:off
return Mono.justOrEmpty(authorizeRequest.<ServerWebExchange>getAttribute(ServerWebExchange.class.getName())) return Mono.justOrEmpty(authorizeRequest.<ServerWebExchange>getAttribute(ServerWebExchange.class.getName()))
.switchIfEmpty(currentServerWebExchangeMono) .switchIfEmpty(currentServerWebExchangeMono)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null"))) .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("serverWebExchange cannot be null")))
.flatMap((serverWebExchange) -> Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) .flatMap((serverWebExchange) -> Mono
.switchIfEmpty(Mono .justOrEmpty(authorizeRequest.getAuthorizedClient())
.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange))) .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(clientRegistrationId, principal, serverWebExchange)))
.flatMap((authorizedClient) -> // Re-authorize .flatMap((authorizedClient) -> // Re-authorize
authorizationContext(authorizeRequest, authorizedClient).flatMap( authorizationContext(authorizeRequest, authorizedClient)
(authorizationContext) -> authorize(authorizationContext, principal, serverWebExchange)) .flatMap((authorizationContext) -> authorize(authorizationContext, principal, serverWebExchange))
// Default to the existing authorizedClient if the // Default to the existing authorizedClient if the
// client was not re-authorized // client was not re-authorized
.defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null) .defaultIfEmpty((authorizeRequest.getAuthorizedClient() != null)
? authorizeRequest.getAuthorizedClient() : authorizedClient)) ? authorizeRequest.getAuthorizedClient() : authorizedClient)
)
.switchIfEmpty(Mono.defer(() -> .switchIfEmpty(Mono.defer(() ->
// Authorize // Authorize
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException( .switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
"Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) "Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
.flatMap((clientRegistration) -> authorizationContext(authorizeRequest, .flatMap((clientRegistration) -> authorizationContext(authorizeRequest,
clientRegistration)) clientRegistration))
.flatMap((authorizationContext) -> authorize(authorizationContext, principal, .flatMap((authorizationContext) -> authorize(authorizationContext, principal,
serverWebExchange))))); serverWebExchange))))
);
// @formatter:on
} }
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal, private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, Authentication principal,
@ -181,18 +194,22 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
*/ */
private Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext authorizationContext, private Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext authorizationContext,
Authentication principal, ServerWebExchange serverWebExchange) { Authentication principal, ServerWebExchange serverWebExchange) {
// @formatter:off
return this.authorizedClientProvider.authorize(authorizationContext) return this.authorizedClientProvider.authorize(authorizationContext)
// Delegate to the authorizationSuccessHandler of the successful // Delegate to the authorizationSuccessHandler of the successful
// authorization // authorization
.flatMap((authorizedClient) -> this.authorizationSuccessHandler .flatMap((authorizedClient) ->
.onAuthorizationSuccess(authorizedClient, principal, createAttributes(serverWebExchange)) this.authorizationSuccessHandler
.thenReturn(authorizedClient)) .onAuthorizationSuccess(authorizedClient, principal, createAttributes(serverWebExchange))
.thenReturn(authorizedClient)
)
// Delegate to the authorizationFailureHandler of the failed authorization // Delegate to the authorizationFailureHandler of the failed authorization
.onErrorResume(OAuth2AuthorizationException.class, .onErrorResume(OAuth2AuthorizationException.class, (authorizationException) ->
(authorizationException) -> this.authorizationFailureHandler this.authorizationFailureHandler
.onAuthorizationFailure(authorizationException, principal, .onAuthorizationFailure(authorizationException, principal, createAttributes(serverWebExchange))
createAttributes(serverWebExchange)) .then(Mono.error(authorizationException))
.then(Mono.error(authorizationException))); );
// @formatter:on
} }
private Map<String, Object> createAttributes(ServerWebExchange serverWebExchange) { private Map<String, Object> createAttributes(ServerWebExchange serverWebExchange) {
@ -201,24 +218,36 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest, private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
OAuth2AuthorizedClient authorizedClient) { OAuth2AuthorizedClient authorizedClient) {
return Mono.just(authorizeRequest).flatMap(this.contextAttributesMapper) // @formatter:off
.map((attrs) -> OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient) return Mono.just(authorizeRequest)
.principal(authorizeRequest.getPrincipal()).attributes((attributes) -> { .flatMap(this.contextAttributesMapper)
.map((attrs) -> OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient)
.principal(authorizeRequest.getPrincipal())
.attributes((attributes) -> {
if (!CollectionUtils.isEmpty(attrs)) { if (!CollectionUtils.isEmpty(attrs)) {
attributes.putAll(attrs); attributes.putAll(attrs);
} }
}).build()); })
.build());
// @formatter:on
} }
private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest, private Mono<OAuth2AuthorizationContext> authorizationContext(OAuth2AuthorizeRequest authorizeRequest,
ClientRegistration clientRegistration) { ClientRegistration clientRegistration) {
return Mono.just(authorizeRequest).flatMap(this.contextAttributesMapper) // @formatter:off
return Mono.just(authorizeRequest)
.flatMap(this.contextAttributesMapper)
.map((attrs) -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration) .map((attrs) -> OAuth2AuthorizationContext.withClientRegistration(clientRegistration)
.principal(authorizeRequest.getPrincipal()).attributes((attributes) -> { .principal(authorizeRequest.getPrincipal())
.attributes((attributes) -> {
if (!CollectionUtils.isEmpty(attrs)) { if (!CollectionUtils.isEmpty(attrs)) {
attributes.putAll(attrs); attributes.putAll(attrs);
} }
}).build()); })
.build()
);
// @formatter:on
} }
/** /**
@ -286,7 +315,9 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
@Override @Override
public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) { public Mono<Map<String, Object>> apply(OAuth2AuthorizeRequest authorizeRequest) {
ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName()); ServerWebExchange serverWebExchange = authorizeRequest.getAttribute(ServerWebExchange.class.getName());
return Mono.justOrEmpty(serverWebExchange).switchIfEmpty(currentServerWebExchangeMono) // @formatter:off
return Mono.justOrEmpty(serverWebExchange)
.switchIfEmpty(currentServerWebExchangeMono)
.flatMap((exchange) -> { .flatMap((exchange) -> {
Map<String, Object> contextAttributes = Collections.emptyMap(); Map<String, Object> contextAttributes = Collections.emptyMap();
String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE); String scope = exchange.getRequest().getQueryParams().getFirst(OAuth2ParameterNames.SCOPE);
@ -296,7 +327,9 @@ public final class DefaultReactiveOAuth2AuthorizedClientManager implements React
StringUtils.delimitedListToStringArray(scope, " ")); StringUtils.delimitedListToStringArray(scope, " "));
} }
return Mono.just(contextAttributes); return Mono.just(contextAttributes);
}).defaultIfEmpty(Collections.emptyMap()); })
.defaultIfEmpty(Collections.emptyMap());
// @formatter:on
} }
} }

View File

@ -71,8 +71,14 @@ final class OAuth2AuthorizationResponseUtils {
} }
String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION);
String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI);
return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri).errorDescription(errorDescription) // @formatter:off
.errorUri(errorUri).state(state).build(); return OAuth2AuthorizationResponse.error(errorCode)
.redirectUri(redirectUri)
.errorDescription(errorDescription)
.errorUri(errorUri)
.state(state)
.build();
// @formatter:on
} }
} }

View File

@ -176,8 +176,12 @@ public class OAuth2LoginAuthenticationFilter extends AbstractAuthenticationProce
"Client Registration not found with Id: " + registrationId, null); "Client Registration not found with Id: " + registrationId, null);
throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString()); throw new OAuth2AuthenticationException(oauth2Error, oauth2Error.toString());
} }
String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request)).replaceQuery(null) // @formatter:off
.build().toUriString(); String redirectUri = UriComponentsBuilder.fromHttpUrl(UrlUtils.buildFullRequestUrl(request))
.replaceQuery(null)
.build()
.toUriString();
// @formatter:on
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params, OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponseUtils.convert(params,
redirectUri); redirectUri);
Object authenticationDetails = this.authenticationDetailsSource.buildDetails(request); Object authenticationDetails = this.authenticationDetailsSource.buildDetails(request);

View File

@ -126,9 +126,14 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
} }
HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class); HttpServletRequest servletRequest = webRequest.getNativeRequest(HttpServletRequest.class);
HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class); HttpServletResponse servletResponse = webRequest.getNativeResponse(HttpServletResponse.class);
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest.withClientRegistrationId(clientRegistrationId) // @formatter:off
.principal(principal).attribute(HttpServletRequest.class.getName(), servletRequest) OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.attribute(HttpServletResponse.class.getName(), servletResponse).build(); .withClientRegistrationId(clientRegistrationId)
.principal(principal)
.attribute(HttpServletRequest.class.getName(), servletRequest)
.attribute(HttpServletResponse.class.getName(), servletResponse)
.build();
// @formatter:on
return this.authorizedClientManager.authorize(authorizeRequest); return this.authorizedClientManager.authorize(authorizeRequest);
} }
@ -176,11 +181,16 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
private void updateDefaultAuthorizedClientManager( private void updateDefaultAuthorizedClientManager(
OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) { OAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode().refreshToken() .authorizationCode()
.clientCredentials( .refreshToken()
(configurer) -> configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)) .clientCredentials((configurer) ->
.password().build(); configurer.accessTokenResponseClient(clientCredentialsTokenResponseClient)
)
.password()
.build();
// @formatter:on
((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager) ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager)
.setAuthorizedClientProvider(authorizedClientProvider); .setAuthorizedClientProvider(authorizedClientProvider);
} }

View File

@ -131,12 +131,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext() private final Mono<Authentication> currentAuthenticationMono = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication).defaultIfEmpty(ANONYMOUS_USER_TOKEN); .map(SecurityContext::getAuthentication).defaultIfEmpty(ANONYMOUS_USER_TOKEN);
// @formatter:off
private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono private final Mono<String> clientRegistrationIdMono = this.currentAuthenticationMono
.filter((t) -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) .filter((t) -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class).map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); .cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
// @formatter:on
// @formatter:off
private final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext() private final Mono<ServerWebExchange> currentServerWebExchangeMono = Mono.subscriberContext()
.filter((c) -> c.hasKey(ServerWebExchange.class)).map((c) -> c.get(ServerWebExchange.class)); .filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class));
// @formatter:on
private final ReactiveOAuth2AuthorizedClientManager authorizedClientManager; private final ReactiveOAuth2AuthorizedClientManager authorizedClientManager;
@ -372,11 +378,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
} }
private void updateDefaultAuthorizedClientManager() { private void updateDefaultAuthorizedClientManager() {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder // @formatter:off
.builder().authorizationCode() ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode()
.refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)) .refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
.clientCredentials(this::updateClientCredentialsProvider) .clientCredentials(this::updateClientCredentialsProvider)
.password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)).build(); .password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
.build();
// @formatter:on
if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) { if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) {
((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager) ((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager)
.setAuthorizedClientProvider(authorizedClientProvider); .setAuthorizedClientProvider(authorizedClientProvider);
@ -418,9 +427,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override @Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return authorizedClient(request).map((authorizedClient) -> bearer(request, authorizedClient)) // @formatter:off
return authorizedClient(request)
.map((authorizedClient) -> bearer(request, authorizedClient))
.flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next)) .flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next))
.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next))); .switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
// @formatter:on
} }
private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
@ -430,22 +442,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) { private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request); OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
// @formatter:off
return Mono.justOrEmpty(authorizedClientFromAttrs) return Mono.justOrEmpty(authorizedClientFromAttrs)
.switchIfEmpty( .switchIfEmpty(Mono.defer(() -> authorizeRequest(request)
Mono.defer(() -> authorizeRequest(request).flatMap(this.authorizedClientManager::authorize))) .flatMap(this.authorizedClientManager::authorize))
)
.flatMap((authorizedClient) -> reauthorizeRequest(request, authorizedClient) .flatMap((authorizedClient) -> reauthorizeRequest(request, authorizedClient)
.flatMap(this.authorizedClientManager::authorize)); .flatMap(this.authorizedClientManager::authorize)
);
// @formatter:on
} }
private Mono<OAuth2AuthorizeRequest> authorizeRequest(ClientRequest request) { private Mono<OAuth2AuthorizeRequest> authorizeRequest(ClientRequest request) {
Mono<String> clientRegistrationId = effectiveClientRegistrationId(request); Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request); Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
return Mono.zip(clientRegistrationId, this.currentAuthenticationMono, serverWebExchange).map((t3) -> { // @formatter:off
OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withClientRegistrationId(t3.getT1()) return Mono.zip(clientRegistrationId, this.currentAuthenticationMono, serverWebExchange)
.principal(t3.getT2()); .map((t3) -> {
t3.getT3().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange)); OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest
return builder.build(); .withClientRegistrationId(t3.getT1())
}); .principal(t3.getT2());
t3.getT3().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange));
return builder.build();
});
// @formatter:on
} }
/** /**
@ -456,9 +476,11 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* given request. * given request.
*/ */
private Mono<String> effectiveClientRegistrationId(ClientRequest request) { private Mono<String> effectiveClientRegistrationId(ClientRequest request) {
// @formatter:off
return Mono.justOrEmpty(clientRegistrationId(request)) return Mono.justOrEmpty(clientRegistrationId(request))
.switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId)) .switchIfEmpty(Mono.justOrEmpty(this.defaultClientRegistrationId))
.switchIfEmpty(this.clientRegistrationIdMono); .switchIfEmpty(this.clientRegistrationIdMono);
// @formatter:on
} }
/** /**
@ -474,24 +496,34 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* {@link ServerWebExchange} that is active for the given request. * {@link ServerWebExchange} that is active for the given request.
*/ */
private Mono<Optional<ServerWebExchange>> effectiveServerWebExchange(ClientRequest request) { private Mono<Optional<ServerWebExchange>> effectiveServerWebExchange(ClientRequest request) {
return Mono.justOrEmpty(serverWebExchange(request)).switchIfEmpty(this.currentServerWebExchangeMono) // @formatter:off
.map(Optional::of).defaultIfEmpty(Optional.empty()); return Mono.justOrEmpty(serverWebExchange(request))
.switchIfEmpty(this.currentServerWebExchangeMono)
.map(Optional::of)
.defaultIfEmpty(Optional.empty());
// @formatter:on
} }
private Mono<OAuth2AuthorizeRequest> reauthorizeRequest(ClientRequest request, private Mono<OAuth2AuthorizeRequest> reauthorizeRequest(ClientRequest request,
OAuth2AuthorizedClient authorizedClient) { OAuth2AuthorizedClient authorizedClient) {
Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request); Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
return Mono.zip(this.currentAuthenticationMono, serverWebExchange).map((t2) -> { // @formatter:off
OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient) return Mono.zip(this.currentAuthenticationMono, serverWebExchange)
.principal(t2.getT1()); .map((t2) -> {
t2.getT2().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange)); OAuth2AuthorizeRequest.Builder builder = OAuth2AuthorizeRequest.withAuthorizedClient(authorizedClient)
return builder.build(); .principal(t2.getT1());
}); t2.getT2().ifPresent((exchange) -> builder.attribute(ServerWebExchange.class.getName(), exchange));
return builder.build();
});
// @formatter:on
} }
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
// @formatter:off
return ClientRequest.from(request) return ClientRequest.from(request)
.headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())).build(); .headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
.build();
// @formatter:on
} }
/** /**
@ -555,10 +587,12 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null"); Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
String clientRegistrationId = authorizeRequest.getClientRegistrationId(); String clientRegistrationId = authorizeRequest.getClientRegistrationId();
Authentication principal = authorizeRequest.getPrincipal(); Authentication principal = authorizeRequest.getPrincipal();
// @formatter:off
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient()) return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
.switchIfEmpty(loadAuthorizedClient(clientRegistrationId, principal)) .switchIfEmpty(loadAuthorizedClient(clientRegistrationId, principal))
.flatMap((authorizedClient) -> reauthorize(authorizedClient, authorizeRequest, principal)) .flatMap((authorizedClient) -> reauthorize(authorizedClient, authorizeRequest, principal))
.switchIfEmpty(findAndAuthorize(clientRegistrationId, principal)); .switchIfEmpty(findAndAuthorize(clientRegistrationId, principal));
// @formatter:on
} }
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId, private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
@ -580,12 +614,18 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
} }
private Mono<OAuth2AuthorizedClient> findAndAuthorize(String clientRegistrationId, Authentication principal) { private Mono<OAuth2AuthorizedClient> findAndAuthorize(String clientRegistrationId, Authentication principal) {
return Mono.defer(() -> this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) // @formatter:off
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException( return Mono.defer(() ->
"Could not find ClientRegistration with id '" + clientRegistrationId + "'"))) this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.flatMap((clientRegistration) -> Mono.just(OAuth2AuthorizationContext .switchIfEmpty(Mono.error(() ->
.withClientRegistration(clientRegistration).principal(principal).build())) new IllegalArgumentException("Could not find ClientRegistration with id '" + clientRegistrationId + "'"))
.flatMap((authorizationContext) -> authorize(authorizationContext, principal))); )
.flatMap((clientRegistration) -> Mono.just(OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(principal).build())
)
.flatMap((authorizationContext) -> authorize(authorizationContext, principal))
);
// @formatter:on
} }
/** /**
@ -601,18 +641,22 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
*/ */
private Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext authorizationContext, private Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizationContext authorizationContext,
Authentication principal) { Authentication principal) {
// @formatter:off
return this.authorizedClientProvider.authorize(authorizationContext) return this.authorizedClientProvider.authorize(authorizationContext)
// Delegates to the authorizationSuccessHandler of the successful // Delegates to the authorizationSuccessHandler of the successful
// authorization // authorization
.flatMap((authorizedClient) -> this.authorizationSuccessHandler .flatMap((authorizedClient) -> this.authorizationSuccessHandler
.onAuthorizationSuccess(authorizedClient, principal, Collections.emptyMap()) .onAuthorizationSuccess(authorizedClient, principal, Collections.emptyMap())
.thenReturn(authorizedClient)) .thenReturn(authorizedClient)
)
// Delegates to the authorizationFailureHandler of the failed // Delegates to the authorizationFailureHandler of the failed
// authorization // authorization
.onErrorResume(OAuth2AuthorizationException.class, .onErrorResume(OAuth2AuthorizationException.class, (authorizationException) ->
(authorizationException) -> this.authorizationFailureHandler this.authorizationFailureHandler
.onAuthorizationFailure(authorizationException, principal, Collections.emptyMap()) .onAuthorizationFailure(authorizationException, principal, Collections.emptyMap())
.then(Mono.error(authorizationException))); .then(Mono.error(authorizationException))
);
// @formatter:on
} }
private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) { private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
@ -653,23 +697,30 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override @Override
public Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> responseMono) { public Mono<ClientResponse> handleResponse(ClientRequest request, Mono<ClientResponse> responseMono) {
return responseMono.flatMap((response) -> handleResponse(request, response).thenReturn(response)) // @formatter:off
return responseMono
.flatMap((response) -> handleResponse(request, response).thenReturn(response))
.onErrorResume(WebClientResponseException.class, .onErrorResume(WebClientResponseException.class,
(e) -> handleWebClientResponseException(request, e).then(Mono.error(e))) (e) -> handleWebClientResponseException(request, e).then(Mono.error(e))
)
.onErrorResume(OAuth2AuthorizationException.class, .onErrorResume(OAuth2AuthorizationException.class,
(e) -> handleAuthorizationException(request, e).then(Mono.error(e))); (e) -> handleAuthorizationException(request, e).then(Mono.error(e)));
// @formatter:on
} }
private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) { private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) {
return Mono.justOrEmpty(resolveErrorIfPossible(response)).flatMap((oauth2Error) -> { // @formatter:off
Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request); return Mono.justOrEmpty(resolveErrorIfPossible(response))
Mono<String> clientRegistrationId = effectiveClientRegistrationId(request); .flatMap((oauth2Error) -> {
return Mono Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
.zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
serverWebExchange, clientRegistrationId) return Mono
.flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono,
new ClientAuthorizationException(oauth2Error, zipped.getT3()))); serverWebExchange, clientRegistrationId)
}); .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(),
new ClientAuthorizationException(oauth2Error, zipped.getT3())));
});
// @formatter:on
} }
private OAuth2Error resolveErrorIfPossible(ClientResponse response) { private OAuth2Error resolveErrorIfPossible(ClientResponse response) {
@ -695,13 +746,19 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
} }
private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) { private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
return Stream.of(wwwAuthenticateHeader).filter((header) -> !StringUtils.isEmpty(header)) // @formatter:off
return Stream.of(wwwAuthenticateHeader)
.filter((header) -> !StringUtils.isEmpty(header))
.filter((header) -> header.toLowerCase().startsWith("bearer")) .filter((header) -> header.toLowerCase().startsWith("bearer"))
.map((header) -> header.substring("bearer".length())).map((header) -> header.split(",")) .map((header) -> header.substring("bearer".length()))
.flatMap(Stream::of).map((parameter) -> parameter.split("=")) .map((header) -> header.split(","))
.flatMap(Stream::of)
.map((parameter) -> parameter.split("="))
.filter((parameter) -> parameter.length > 1) .filter((parameter) -> parameter.length > 1)
.collect(Collectors.toMap((parameters) -> parameters[0].trim(), .collect(Collectors.toMap((parameters) -> parameters[0].trim(),
(parameters) -> parameters[1].trim().replace("\"", ""))); (parameters) -> parameters[1].trim().replace("\"", ""))
);
// @formatter:on
} }
/** /**
@ -715,15 +772,16 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
*/ */
private Mono<Void> handleWebClientResponseException(ClientRequest request, private Mono<Void> handleWebClientResponseException(ClientRequest request,
WebClientResponseException exception) { WebClientResponseException exception) {
return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode())).flatMap((oauth2Error) -> { return Mono.justOrEmpty(resolveErrorIfPossible(exception.getRawStatusCode()))
Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request); .flatMap((oauth2Error) -> {
Mono<String> clientRegistrationId = effectiveClientRegistrationId(request); Mono<Optional<ServerWebExchange>> serverWebExchange = effectiveServerWebExchange(request);
return Mono Mono<String> clientRegistrationId = effectiveClientRegistrationId(request);
.zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono, return Mono
serverWebExchange, clientRegistrationId) .zip(ServerOAuth2AuthorizedClientExchangeFilterFunction.this.currentAuthenticationMono,
.flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(), serverWebExchange, clientRegistrationId)
new ClientAuthorizationException(oauth2Error, zipped.getT3(), exception))); .flatMap((zipped) -> handleAuthorizationFailure(zipped.getT1(), zipped.getT2(),
}); new ClientAuthorizationException(oauth2Error, zipped.getT3(), exception)));
});
} }
/** /**

View File

@ -262,10 +262,14 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
} }
private void updateDefaultAuthorizedClientManager() { private void updateDefaultAuthorizedClientManager() {
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode().refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)) .authorizationCode()
.refreshToken((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
.clientCredentials(this::updateClientCredentialsProvider) .clientCredentials(this::updateClientCredentialsProvider)
.password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew)).build(); .password((configurer) -> configurer.clockSkew(this.accessTokenExpiresSkew))
.build();
// @formatter:on
((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager) ((DefaultOAuth2AuthorizedClientManager) this.authorizedClientManager)
.setAuthorizedClientProvider(authorizedClientProvider); .setAuthorizedClientProvider(authorizedClientProvider);
} }
@ -435,15 +439,21 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
@Override @Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) { public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
// @formatter:off
return mergeRequestAttributesIfNecessary(request) return mergeRequestAttributesIfNecessary(request)
.filter((req) -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent()) .filter((req) -> req.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME).isPresent())
.flatMap((req) -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req)) .flatMap((req) -> reauthorizeClient(getOAuth2AuthorizedClient(req.attributes()), req))
.switchIfEmpty(Mono.defer(() -> mergeRequestAttributesIfNecessary(request) .switchIfEmpty(
.filter((req) -> resolveClientRegistrationId(req) != null) Mono.defer(() ->
.flatMap((req) -> authorizeClient(resolveClientRegistrationId(req), req)))) mergeRequestAttributesIfNecessary(request)
.filter((req) -> resolveClientRegistrationId(req) != null)
.flatMap((req) -> authorizeClient(resolveClientRegistrationId(req), req))
)
)
.map((authorizedClient) -> bearer(request, authorizedClient)) .map((authorizedClient) -> bearer(request, authorizedClient))
.flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next)) .flatMap((requestWithBearer) -> exchangeAndHandleResponse(requestWithBearer, next))
.switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next))); .switchIfEmpty(Mono.defer(() -> exchangeAndHandleResponse(request, next)));
// @formatter:on
} }
private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) { private Mono<ClientResponse> exchangeAndHandleResponse(ClientRequest request, ExchangeFunction next) {
@ -577,9 +587,12 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
} }
private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) { private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient authorizedClient) {
// @formatter:off
return ClientRequest.from(request) return ClientRequest.from(request)
.headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue())) .headers((headers) -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
.attributes(oauth2AuthorizedClient(authorizedClient)).build(); .attributes(oauth2AuthorizedClient(authorizedClient))
.build();
// @formatter:on
} }
static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) { static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> attrs) {
@ -664,19 +677,22 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
} }
private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) { private Mono<Void> handleResponse(ClientRequest request, ClientResponse response) {
return Mono.justOrEmpty(resolveErrorIfPossible(response)).flatMap((oauth2Error) -> { // @formatter:off
Map<String, Object> attrs = request.attributes(); return Mono.justOrEmpty(resolveErrorIfPossible(response))
OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs); .flatMap((oauth2Error) -> {
if (authorizedClient == null) { Map<String, Object> attrs = request.attributes();
return Mono.empty(); OAuth2AuthorizedClient authorizedClient = getOAuth2AuthorizedClient(attrs);
} if (authorizedClient == null) {
ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error, return Mono.empty();
authorizedClient.getClientRegistration().getRegistrationId()); }
Authentication principal = createAuthentication(authorizedClient.getPrincipalName()); ClientAuthorizationException authorizationException = new ClientAuthorizationException(oauth2Error,
HttpServletRequest servletRequest = getRequest(attrs); authorizedClient.getClientRegistration().getRegistrationId());
HttpServletResponse servletResponse = getResponse(attrs); Authentication principal = createAuthentication(authorizedClient.getPrincipalName());
return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse); HttpServletRequest servletRequest = getRequest(attrs);
}); HttpServletResponse servletResponse = getResponse(attrs);
return handleAuthorizationFailure(authorizationException, principal, servletRequest, servletResponse);
});
// @formatter:on
} }
private OAuth2Error resolveErrorIfPossible(ClientResponse response) { private OAuth2Error resolveErrorIfPossible(ClientResponse response) {
@ -702,13 +718,18 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
} }
private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) { private Map<String, String> parseAuthParameters(String wwwAuthenticateHeader) {
// @formatter:off
return Stream.of(wwwAuthenticateHeader).filter((header) -> !StringUtils.isEmpty(header)) return Stream.of(wwwAuthenticateHeader).filter((header) -> !StringUtils.isEmpty(header))
.filter((header) -> header.toLowerCase().startsWith("bearer")) .filter((header) -> header.toLowerCase().startsWith("bearer"))
.map((header) -> header.substring("bearer".length())).map((header) -> header.split(",")) .map((header) -> header.substring("bearer".length()))
.flatMap(Stream::of).map((parameter) -> parameter.split("=")) .map((header) -> header.split(","))
.flatMap(Stream::of)
.map((parameter) -> parameter.split("="))
.filter((parameter) -> parameter.length > 1) .filter((parameter) -> parameter.length > 1)
.collect(Collectors.toMap((parameters) -> parameters[0].trim(), .collect(Collectors.toMap((parameters) -> parameters[0].trim(),
(parameters) -> parameters[1].trim().replace("\"", ""))); (parameters) -> parameters[1].trim().replace("\"", ""))
);
// @formatter:on
} }
/** /**
@ -776,7 +797,11 @@ public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implement
HttpServletRequest servletRequest, HttpServletResponse servletResponse) { HttpServletRequest servletRequest, HttpServletResponse servletResponse) {
Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure(exception, principal, Runnable runnable = () -> this.authorizationFailureHandler.onAuthorizationFailure(exception, principal,
createAttributes(servletRequest, servletResponse)); createAttributes(servletRequest, servletResponse));
return Mono.fromRunnable(runnable).subscribeOn(Schedulers.boundedElastic()).then(); // @formatter:off
return Mono.fromRunnable(runnable)
.subscribeOn(Schedulers.boundedElastic())
.then();
// @formatter:on
} }
private static Map<String, Object> createAttributes(HttpServletRequest servletRequest, private static Map<String, Object> createAttributes(HttpServletRequest servletRequest,

View File

@ -126,8 +126,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
} }
private Mono<Authentication> currentAuthentication() { private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) // @formatter:off
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN); .defaultIfEmpty(ANONYMOUS_USER_TOKEN);
// @formatter:on
} }
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) { private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
@ -137,8 +140,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
} }
private Mono<ServerWebExchange> currentServerWebExchange() { private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext().filter((c) -> c.hasKey(ServerWebExchange.class)) // @formatter:off
return Mono.subscriberContext()
.filter((c) -> c.hasKey(ServerWebExchange.class))
.map((c) -> c.get(ServerWebExchange.class)); .map((c) -> c.get(ServerWebExchange.class));
// @formatter:on
} }
} }

View File

@ -120,10 +120,15 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA
@Override @Override
public Mono<OAuth2AuthorizationRequest> resolve(ServerWebExchange exchange) { public Mono<OAuth2AuthorizationRequest> resolve(ServerWebExchange exchange) {
return this.authorizationRequestMatcher.matches(exchange).filter((matchResult) -> matchResult.isMatch()) // @formatter:off
return this.authorizationRequestMatcher
.matches(exchange)
.filter((matchResult) -> matchResult.isMatch())
.map(ServerWebExchangeMatcher.MatchResult::getVariables) .map(ServerWebExchangeMatcher.MatchResult::getVariables)
.map((variables) -> variables.get(DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME)).cast(String.class) .map((variables) -> variables.get(DEFAULT_REGISTRATION_ID_URI_VARIABLE_NAME))
.cast(String.class)
.flatMap((clientRegistrationId) -> resolve(exchange, clientRegistrationId)); .flatMap((clientRegistrationId) -> resolve(exchange, clientRegistrationId));
// @formatter:on
} }
@Override @Override
@ -146,8 +151,10 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA
} }
private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) { private Mono<ClientRegistration> findByRegistrationId(ServerWebExchange exchange, String clientRegistration) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistration).switchIfEmpty(Mono // @formatter:off
.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id"))); return this.clientRegistrationRepository.findByRegistrationId(clientRegistration)
.switchIfEmpty(Mono.error(() -> new ResponseStatusException(HttpStatus.BAD_REQUEST, "Invalid client registration id")));
// @formatter:on
} }
private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange, private OAuth2AuthorizationRequest authorizationRequest(ServerWebExchange exchange,
@ -156,10 +163,14 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA
Map<String, Object> attributes = new HashMap<>(); Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId()); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, clientRegistration.getRegistrationId());
OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes); OAuth2AuthorizationRequest.Builder builder = getBuilder(clientRegistration, attributes);
// @formatter:off
builder.clientId(clientRegistration.getClientId()) builder.clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(redirectUriStr).scopes(clientRegistration.getScopes()) .redirectUri(redirectUriStr)
.state(this.stateGenerator.generateKey()).attributes(attributes); .scopes(clientRegistration.getScopes())
.state(this.stateGenerator.generateKey())
.attributes(attributes);
// @formatter:on
this.authorizationRequestCustomizer.accept(builder); this.authorizationRequestCustomizer.accept(builder);
@ -214,8 +225,13 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA
private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) { private static String expandRedirectUri(ServerHttpRequest request, ClientRegistration clientRegistration) {
Map<String, String> uriVariables = new HashMap<>(); Map<String, String> uriVariables = new HashMap<>();
uriVariables.put("registrationId", clientRegistration.getRegistrationId()); uriVariables.put("registrationId", clientRegistration.getRegistrationId());
// @formatter:off
UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI()) UriComponents uriComponents = UriComponentsBuilder.fromUri(request.getURI())
.replacePath(request.getPath().contextPath().value()).replaceQuery(null).fragment(null).build(); .replacePath(request.getPath().contextPath().value())
.replaceQuery(null)
.fragment(null)
.build();
// @formatter:on
String scheme = uriComponents.getScheme(); String scheme = uriComponents.getScheme();
uriVariables.put("baseScheme", (scheme != null) ? scheme : ""); uriVariables.put("baseScheme", (scheme != null) ? scheme : "");
String host = uriComponents.getHost(); String host = uriComponents.getHost();
@ -236,8 +252,11 @@ public class DefaultServerOAuth2AuthorizationRequestResolver implements ServerOA
action = "login"; action = "login";
} }
uriVariables.put("action", action); uriVariables.put("action", action);
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri()).buildAndExpand(uriVariables) // @formatter:off
return UriComponentsBuilder.fromUriString(clientRegistration.getRedirectUri())
.buildAndExpand(uriVariables)
.toUriString(); .toUriString();
// @formatter:on
} }
/** /**

View File

@ -202,15 +202,20 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
// @formatter:off
return this.requiresAuthenticationMatcher.matches(exchange) return this.requiresAuthenticationMatcher.matches(exchange)
.filter(ServerWebExchangeMatcher.MatchResult::isMatch) .filter(ServerWebExchangeMatcher.MatchResult::isMatch)
.flatMap((matchResult) -> this.authenticationConverter.convert(exchange).onErrorMap( .flatMap((matchResult) -> this.authenticationConverter.convert(exchange)
OAuth2AuthorizationException.class, .onErrorMap(OAuth2AuthorizationException.class,
(ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString()))) (ex) -> new OAuth2AuthenticationException(ex.getError(), ex.getError().toString())
)
)
.switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.flatMap((token) -> authenticate(exchange, chain, token)) .flatMap((token) -> authenticate(exchange, chain, token))
.onErrorResume(AuthenticationException.class, (e) -> this.authenticationFailureHandler .onErrorResume(AuthenticationException.class, (e) ->
.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)); this.authenticationFailureHandler.onAuthenticationFailure(new WebFilterExchange(exchange, chain), e)
);
// @formatter:on
} }
private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) { private Mono<Void> authenticate(ServerWebExchange exchange, WebFilterChain chain, Authentication token) {
@ -230,19 +235,30 @@ public class OAuth2AuthorizationCodeGrantWebFilter implements WebFilter {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
authenticationResult.getClientRegistration(), authenticationResult.getName(), authenticationResult.getClientRegistration(), authenticationResult.getName(),
authenticationResult.getAccessToken(), authenticationResult.getRefreshToken()); authenticationResult.getAccessToken(), authenticationResult.getRefreshToken());
// @formatter:off
return this.authenticationSuccessHandler.onAuthenticationSuccess(webFilterExchange, authentication) return this.authenticationSuccessHandler.onAuthenticationSuccess(webFilterExchange, authentication)
.then(ReactiveSecurityContextHolder.getContext().map(SecurityContext::getAuthentication) .then(ReactiveSecurityContextHolder.getContext()
.defaultIfEmpty(this.anonymousToken).flatMap((principal) -> this.authorizedClientRepository .map(SecurityContext::getAuthentication)
.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange()))); .defaultIfEmpty(this.anonymousToken)
.flatMap((principal) -> this.authorizedClientRepository
.saveAuthorizedClient(authorizedClient, principal, webFilterExchange.getExchange())
)
);
// @formatter:on
} }
private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) { private Mono<ServerWebExchangeMatcher.MatchResult> matchesAuthorizationResponse(ServerWebExchange exchange) {
return Mono.just(exchange).filter( // @formatter:off
(exch) -> OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams())) return Mono.just(exchange)
.filter((exch) ->
OAuth2AuthorizationResponseUtils.isAuthorizationResponse(exch.getRequest().getQueryParams())
)
.flatMap((exch) -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange) .flatMap((exch) -> this.authorizationRequestRepository.loadAuthorizationRequest(exchange)
.flatMap((authorizationRequest) -> matchesRedirectUri(exch.getRequest().getURI(), .flatMap((authorizationRequest) -> matchesRedirectUri(exch.getRequest().getURI(),
authorizationRequest.getRedirectUri()))) authorizationRequest.getRedirectUri()))
)
.switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch()); .switchIfEmpty(ServerWebExchangeMatcher.MatchResult.notMatch());
// @formatter:on
} }
private static Mono<ServerWebExchangeMatcher.MatchResult> matchesRedirectUri(URI authorizationResponseUri, private static Mono<ServerWebExchangeMatcher.MatchResult> matchesRedirectUri(URI authorizationResponseUri,

View File

@ -127,12 +127,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
@Override @Override
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
// @formatter:off
return this.authorizationRequestResolver.resolve(exchange) return this.authorizationRequestResolver.resolve(exchange)
.switchIfEmpty(chain.filter(exchange).then(Mono.empty())) .switchIfEmpty(chain.filter(exchange).then(Mono.empty()))
.onErrorResume(ClientAuthorizationRequiredException.class, .onErrorResume(ClientAuthorizationRequiredException.class,
(ex) -> this.requestCache.saveRequest(exchange).then( (ex) -> this.requestCache.saveRequest(exchange).then(
this.authorizationRequestResolver.resolve(exchange, ex.getClientRegistrationId()))) this.authorizationRequestResolver.resolve(exchange, ex.getClientRegistrationId()))
)
.flatMap((clientRegistration) -> sendRedirectForAuthorization(exchange, clientRegistration)); .flatMap((clientRegistration) -> sendRedirectForAuthorization(exchange, clientRegistration));
// @formatter:on
} }
private Mono<Void> sendRedirectForAuthorization(ServerWebExchange exchange, private Mono<Void> sendRedirectForAuthorization(ServerWebExchange exchange,
@ -143,8 +146,11 @@ public class OAuth2AuthorizationRequestRedirectWebFilter implements WebFilter {
saveAuthorizationRequest = this.authorizationRequestRepository saveAuthorizationRequest = this.authorizationRequestRepository
.saveAuthorizationRequest(authorizationRequest, exchange); .saveAuthorizationRequest(authorizationRequest, exchange);
} }
// @formatter:off
URI redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getAuthorizationRequestUri()) URI redirectUri = UriComponentsBuilder.fromUriString(authorizationRequest.getAuthorizationRequestUri())
.build(true).toUri(); .build(true)
.toUri();
// @formatter:on
return saveAuthorizationRequest return saveAuthorizationRequest
.then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri)); .then(this.authorizationRedirectStrategy.sendRedirect(exchange, redirectUri));
}); });

View File

@ -71,8 +71,14 @@ final class OAuth2AuthorizationResponseUtils {
} }
String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION); String errorDescription = request.getFirst(OAuth2ParameterNames.ERROR_DESCRIPTION);
String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI); String errorUri = request.getFirst(OAuth2ParameterNames.ERROR_URI);
return OAuth2AuthorizationResponse.error(errorCode).redirectUri(redirectUri).errorDescription(errorDescription) // @formatter:off
.errorUri(errorUri).state(state).build(); return OAuth2AuthorizationResponse.error(errorCode)
.redirectUri(redirectUri)
.errorDescription(errorDescription)
.errorUri(errorUri)
.state(state)
.build();
// @formatter:on
} }
} }

View File

@ -70,9 +70,11 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter implement
@Override @Override
public Mono<Authentication> convert(ServerWebExchange serverWebExchange) { public Mono<Authentication> convert(ServerWebExchange serverWebExchange) {
// @formatter:off
return this.authorizationRequestRepository.removeAuthorizationRequest(serverWebExchange) return this.authorizationRequestRepository.removeAuthorizationRequest(serverWebExchange)
.switchIfEmpty(oauth2AuthorizationException(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE)) .switchIfEmpty(oauth2AuthorizationException(AUTHORIZATION_REQUEST_NOT_FOUND_ERROR_CODE))
.flatMap((authorizationRequest) -> authenticationRequest(serverWebExchange, authorizationRequest)); .flatMap((authorizationRequest) -> authenticationRequest(serverWebExchange, authorizationRequest));
// @formatter:on
} }
private <T> Mono<T> oauth2AuthorizationException(String errorCode) { private <T> Mono<T> oauth2AuthorizationException(String errorCode) {
@ -84,13 +86,16 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter implement
private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(ServerWebExchange exchange, private Mono<OAuth2AuthorizationCodeAuthenticationToken> authenticationRequest(ServerWebExchange exchange,
OAuth2AuthorizationRequest authorizationRequest) { OAuth2AuthorizationRequest authorizationRequest) {
return Mono.just(authorizationRequest).map(OAuth2AuthorizationRequest::getAttributes).flatMap((attributes) -> { // @formatter:off
String id = (String) attributes.get(OAuth2ParameterNames.REGISTRATION_ID); return Mono.just(authorizationRequest)
if (id == null) { .map(OAuth2AuthorizationRequest::getAttributes).flatMap((attributes) -> {
return oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE); String id = (String) attributes.get(OAuth2ParameterNames.REGISTRATION_ID);
} if (id == null) {
return this.clientRegistrationRepository.findByRegistrationId(id); return oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE);
}).switchIfEmpty(oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE)) }
return this.clientRegistrationRepository.findByRegistrationId(id);
})
.switchIfEmpty(oauth2AuthorizationException(CLIENT_REGISTRATION_NOT_FOUND_ERROR_CODE))
.map((clientRegistration) -> { .map((clientRegistration) -> {
OAuth2AuthorizationResponse authorizationResponse = convertResponse(exchange); OAuth2AuthorizationResponse authorizationResponse = convertResponse(exchange);
OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken( OAuth2AuthorizationCodeAuthenticationToken authenticationRequest = new OAuth2AuthorizationCodeAuthenticationToken(
@ -98,6 +103,7 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverter implement
new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse)); new OAuth2AuthorizationExchange(authorizationRequest, authorizationResponse));
return authenticationRequest; return authenticationRequest;
}); });
// @formatter:on
} }
private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) { private static OAuth2AuthorizationResponse convertResponse(ServerWebExchange exchange) {

View File

@ -52,19 +52,23 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
if (state == null) { if (state == null) {
return Mono.empty(); return Mono.empty();
} }
// @formatter:off
return getStateToAuthorizationRequest(exchange) return getStateToAuthorizationRequest(exchange)
.filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state)) .filter((stateToAuthorizationRequest) -> stateToAuthorizationRequest.containsKey(state))
.map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state)); .map((stateToAuthorizationRequest) -> stateToAuthorizationRequest.get(state));
// @formatter:on
} }
@Override @Override
public Mono<Void> saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest, public Mono<Void> saveAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest,
ServerWebExchange exchange) { ServerWebExchange exchange) {
Assert.notNull(authorizationRequest, "authorizationRequest cannot be null"); Assert.notNull(authorizationRequest, "authorizationRequest cannot be null");
// @formatter:off
return saveStateToAuthorizationRequest(exchange) return saveStateToAuthorizationRequest(exchange)
.doOnNext((stateToAuthorizationRequest) -> stateToAuthorizationRequest .doOnNext((stateToAuthorizationRequest) -> stateToAuthorizationRequest
.put(authorizationRequest.getState(), authorizationRequest)) .put(authorizationRequest.getState(), authorizationRequest))
.then(); .then();
// @formatter:on
} }
@Override @Override
@ -73,29 +77,33 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
if (state == null) { if (state == null) {
return Mono.empty(); return Mono.empty();
} }
return exchange.getSession().map(WebSession::getAttributes).handle((sessionAttrs, sink) -> { // @formatter:off
Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest( return exchange.getSession()
sessionAttrs); .map(WebSession::getAttributes)
if (stateToAuthzRequest == null) { .handle((sessionAttrs, sink) -> {
sink.complete(); Map<String, OAuth2AuthorizationRequest> stateToAuthzRequest = sessionAttrsMapStateToAuthorizationRequest(
return; sessionAttrs);
} if (stateToAuthzRequest == null) {
OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state); sink.complete();
if (stateToAuthzRequest.isEmpty()) { return;
sessionAttrs.remove(this.sessionAttributeName); }
} OAuth2AuthorizationRequest removedValue = stateToAuthzRequest.remove(state);
else if (removedValue != null) { if (stateToAuthzRequest.isEmpty()) {
// gh-7327 Overwrite the existing Map to ensure the state is saved for sessionAttrs.remove(this.sessionAttributeName);
// distributed sessions }
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); else if (removedValue != null) {
} // gh-7327 Overwrite the existing Map to ensure the state is saved for
if (removedValue == null) { // distributed sessions
sink.complete(); sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
} }
else { if (removedValue == null) {
sink.next(removedValue); sink.complete();
} }
}); else {
sink.next(removedValue);
}
});
// @formatter:on
} }
/** /**
@ -115,22 +123,28 @@ public final class WebSessionOAuth2ServerAuthorizationRequestRepository
private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) { private Mono<Map<String, OAuth2AuthorizationRequest>> getStateToAuthorizationRequest(ServerWebExchange exchange) {
Assert.notNull(exchange, "exchange cannot be null"); Assert.notNull(exchange, "exchange cannot be null");
return getSessionAttributes(exchange).flatMap( // @formatter:off
(sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); return getSessionAttributes(exchange)
.flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
// @formatter:on
} }
private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) { private Mono<Map<String, OAuth2AuthorizationRequest>> saveStateToAuthorizationRequest(ServerWebExchange exchange) {
Assert.notNull(exchange, "exchange cannot be null"); Assert.notNull(exchange, "exchange cannot be null");
return getSessionAttributes(exchange).doOnNext((sessionAttrs) -> { // @formatter:off
Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName); return getSessionAttributes(exchange)
if (stateToAuthzRequest == null) { .doOnNext((sessionAttrs) -> {
stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>(); Object stateToAuthzRequest = sessionAttrs.get(this.sessionAttributeName);
} if (stateToAuthzRequest == null) {
// No matter stateToAuthzRequest was in session or not, we should always put stateToAuthzRequest = new HashMap<String, OAuth2AuthorizationRequest>();
// it into session again }
// in case of redis or hazelcast session. #6215 // No matter stateToAuthzRequest was in session or not, we should always put
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest); // it into session again
}).flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs))); // in case of redis or hazelcast session. #6215
sessionAttrs.put(this.sessionAttributeName, stateToAuthzRequest);
})
.flatMap((sessionAttrs) -> Mono.justOrEmpty(this.sessionAttrsMapStateToAuthorizationRequest(sessionAttrs)));
// @formatter:on
} }
private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest( private Map<String, OAuth2AuthorizationRequest> sessionAttrsMapStateToAuthorizationRequest(

View File

@ -50,8 +50,11 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S
Authentication principal, ServerWebExchange exchange) { Authentication principal, ServerWebExchange exchange) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.notNull(exchange, "exchange cannot be null"); Assert.notNull(exchange, "exchange cannot be null");
return exchange.getSession().map(this::getAuthorizedClients) // @formatter:off
return exchange.getSession()
.map(this::getAuthorizedClients)
.flatMap((clients) -> Mono.justOrEmpty((T) clients.get(clientRegistrationId))); .flatMap((clients) -> Mono.justOrEmpty((T) clients.get(clientRegistrationId)));
// @formatter:on
} }
@Override @Override
@ -59,11 +62,15 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S
ServerWebExchange exchange) { ServerWebExchange exchange) {
Assert.notNull(authorizedClient, "authorizedClient cannot be null"); Assert.notNull(authorizedClient, "authorizedClient cannot be null");
Assert.notNull(exchange, "exchange cannot be null"); Assert.notNull(exchange, "exchange cannot be null");
return exchange.getSession().doOnSuccess((session) -> { // @formatter:off
Map<String, OAuth2AuthorizedClient> authorizedClients = getAuthorizedClients(session); return exchange.getSession()
authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient); .doOnSuccess((session) -> {
session.getAttributes().put(this.sessionAttributeName, authorizedClients); Map<String, OAuth2AuthorizedClient> authorizedClients = getAuthorizedClients(session);
}).then(Mono.empty()); authorizedClients.put(authorizedClient.getClientRegistration().getRegistrationId(), authorizedClient);
session.getAttributes().put(this.sessionAttributeName, authorizedClients);
})
.then(Mono.empty());
// @formatter:on
} }
@Override @Override
@ -71,16 +78,20 @@ public final class WebSessionServerOAuth2AuthorizedClientRepository implements S
ServerWebExchange exchange) { ServerWebExchange exchange) {
Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty"); Assert.hasText(clientRegistrationId, "clientRegistrationId cannot be empty");
Assert.notNull(exchange, "exchange cannot be null"); Assert.notNull(exchange, "exchange cannot be null");
return exchange.getSession().doOnSuccess((session) -> { // @formatter:off
Map<String, OAuth2AuthorizedClient> authorizedClients = getAuthorizedClients(session); return exchange.getSession()
authorizedClients.remove(clientRegistrationId); .doOnSuccess((session) -> {
if (authorizedClients.isEmpty()) { Map<String, OAuth2AuthorizedClient> authorizedClients = getAuthorizedClients(session);
session.getAttributes().remove(this.sessionAttributeName); authorizedClients.remove(clientRegistrationId);
} if (authorizedClients.isEmpty()) {
else { session.getAttributes().remove(this.sessionAttributeName);
session.getAttributes().put(this.sessionAttributeName, authorizedClients); }
} else {
}).then(Mono.empty()); session.getAttributes().put(this.sessionAttributeName, authorizedClients);
}
})
.then(Mono.empty());
// @formatter:on
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")

View File

@ -61,9 +61,11 @@ public class OAuth2LoginAuthenticationWebFilter extends AuthenticationWebFilter
OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(authenticationResult.getPrincipal(), OAuth2AuthenticationToken result = new OAuth2AuthenticationToken(authenticationResult.getPrincipal(),
authenticationResult.getAuthorities(), authenticationResult.getAuthorities(),
authenticationResult.getClientRegistration().getRegistrationId()); authenticationResult.getClientRegistration().getRegistrationId());
// @formatter:off
return this.authorizedClientRepository return this.authorizedClientRepository
.saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange()) .saveAuthorizedClient(authorizedClient, authenticationResult, webFilterExchange.getExchange())
.then(super.onAuthenticationSuccess(result, webFilterExchange)); .then(super.onAuthenticationSuccess(result, webFilterExchange));
// @formatter:on
} }
} }

View File

@ -61,22 +61,31 @@ public class AuthorizationCodeOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() {
ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientCredentialsClient).principal(this.principal).build(); .withClientRegistration(clientCredentialsClient).principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@Test @Test
public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal).build(); .withAuthorizedClient(this.authorizedClient).principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@Test @Test
public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration).principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)); .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext));
} }

View File

@ -61,22 +61,34 @@ public class AuthorizationCodeReactiveOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() { public void authorizeWhenNotAuthorizationCodeThenUnableToAuthorize() {
ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build(); ClientRegistration clientCredentialsClient = TestClientRegistrations.clientCredentials().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientCredentialsClient).principal(this.principal).build(); .withClientRegistration(clientCredentialsClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@Test @Test
public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() { public void authorizeWhenAuthorizationCodeAndAuthorizedThenNotAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal).build(); .withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@Test @Test
public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() { public void authorizeWhenAuthorizationCodeAndNotAuthorizedThenAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()); .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block());
} }

View File

@ -108,58 +108,78 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
@Test @Test
public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() { public void constructorWhenClientRegistrationRepositoryIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService)) assertThatIllegalArgumentException()
.isThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(null, this.authorizedClientService))
.withMessage("clientRegistrationRepository cannot be null"); .withMessage("clientRegistrationRepository cannot be null");
// @formatter:on
} }
@Test @Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null)) assertThatIllegalArgumentException()
.isThrownBy(() -> new AuthorizedClientServiceOAuth2AuthorizedClientManager(this.clientRegistrationRepository, null))
.withMessage("authorizedClientService cannot be null"); .withMessage("authorizedClientService cannot be null");
// @formatter:on
} }
@Test @Test
public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() { public void setAuthorizedClientProviderWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null)) .isThrownBy(() -> this.authorizedClientManager.setAuthorizedClientProvider(null))
.withMessage("authorizedClientProvider cannot be null"); .withMessage("authorizedClientProvider cannot be null");
// @formatter:on
} }
@Test @Test
public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() { public void setContextAttributesMapperWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null)) .isThrownBy(() -> this.authorizedClientManager.setContextAttributesMapper(null))
.withMessage("contextAttributesMapper cannot be null"); .withMessage("contextAttributesMapper cannot be null");
// @formatter:on
} }
@Test @Test
public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() { public void setAuthorizationSuccessHandlerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null)) .isThrownBy(() -> this.authorizedClientManager.setAuthorizationSuccessHandler(null))
.withMessage("authorizationSuccessHandler cannot be null"); .withMessage("authorizationSuccessHandler cannot be null");
// @formatter:on
} }
@Test @Test
public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() { public void setAuthorizationFailureHandlerWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null)) .isThrownBy(() -> this.authorizedClientManager.setAuthorizationFailureHandler(null))
.withMessage("authorizationFailureHandler cannot be null"); .withMessage("authorizationFailureHandler cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() { public void authorizeWhenRequestIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.authorize(null))
.withMessage("authorizeRequest cannot be null"); .withMessage("authorizeRequest cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId("invalid-registration-id").principal(this.principal).build(); .withClientRegistrationId("invalid-registration-id")
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest))
.principal(this.principal).build();
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest))
.withMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); .withMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
// @formatter:on
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -189,9 +209,12 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
.willReturn(this.clientRegistration); .willReturn(this.clientRegistration);
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(this.authorizedClient); .willReturn(this.authorizedClient);
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -216,9 +239,12 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(reauthorizedClient); .willReturn(reauthorizedClient);
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -324,11 +350,15 @@ public class AuthorizedClientServiceOAuth2AuthorizedClientManagerTests {
new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId()); new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willThrow(authorizationException); .willThrow(authorizationException);
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) // @formatter:off
.principal(this.principal).build(); OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest
.withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
assertThatExceptionOfType(ClientAuthorizationException.class) assertThatExceptionOfType(ClientAuthorizationException.class)
.isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException); .isEqualTo(authorizationException);
// @formatter:on
verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal), verify(this.authorizationFailureHandler).onAuthorizationFailure(eq(authorizationException), eq(this.principal),
any()); any());
verifyNoInteractions(this.authorizedClientService); verifyNoInteractions(this.authorizedClientService);

View File

@ -170,9 +170,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
given(this.authorizedClientProvider.authorize(any())).willReturn(Mono.empty()); given(this.authorizedClientProvider.authorize(any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
StepVerifier.create(authorizedClient).verifyComplete(); StepVerifier.create(authorizedClient).verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -217,9 +220,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(Mono.just(this.authorizedClient)); .willReturn(Mono.just(this.authorizedClient));
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
PublisherProbe<Void> authorizationSuccessHandlerProbe = PublisherProbe.empty(); PublisherProbe<Void> authorizationSuccessHandlerProbe = PublisherProbe.empty();
this.authorizedClientManager.setAuthorizationSuccessHandler( this.authorizedClientManager.setAuthorizationSuccessHandler(
(client, principal, attributes) -> authorizationSuccessHandlerProbe.mono()); (client, principal, attributes) -> authorizationSuccessHandlerProbe.mono());
@ -241,9 +247,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
ClientAuthorizationException exception = new ClientAuthorizationException( ClientAuthorizationException exception = new ClientAuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null), new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null),
this.clientRegistration.getRegistrationId()); this.clientRegistration.getRegistrationId());
@ -269,9 +278,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
ClientAuthorizationException exception = new ClientAuthorizationException( ClientAuthorizationException exception = new ClientAuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null), new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null),
this.clientRegistration.getRegistrationId()); this.clientRegistration.getRegistrationId());
@ -297,9 +309,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
ClientAuthorizationException exception = new ClientAuthorizationException( ClientAuthorizationException exception = new ClientAuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null), new OAuth2Error(OAuth2ErrorCodes.SERVER_ERROR, null, null),
this.clientRegistration.getRegistrationId()); this.clientRegistration.getRegistrationId());
@ -323,9 +338,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizationException exception = new OAuth2AuthorizationException( OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
@ -348,9 +366,12 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty()); given(this.authorizedClientService.loadAuthorizedClient(any(), any())).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizationException exception = new OAuth2AuthorizationException( OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null)); new OAuth2Error(OAuth2ErrorCodes.INVALID_GRANT, null, null));
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
@ -387,7 +408,11 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal)
.build(); .build();
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
StepVerifier.create(authorizedClient).expectNext(reauthorizedClient).verifyComplete(); // @formatter:off
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
// @formatter:on
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
@ -403,8 +428,11 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
@Test @Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).willReturn(Mono.empty()); given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))).willReturn(Mono.empty());
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).build(); .principal(this.principal)
.build();
// @formatter:on
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
StepVerifier.create(authorizedClient).expectNext(this.authorizedClient).verifyComplete(); StepVerifier.create(authorizedClient).expectNext(this.authorizedClient).verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -424,8 +452,11 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(Mono.just(reauthorizedClient)); .willReturn(Mono.just(reauthorizedClient));
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).build(); .principal(this.principal)
.build();
// @formatter:on
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
StepVerifier.create(authorizedClient).expectNext(reauthorizedClient).verifyComplete(); StepVerifier.create(authorizedClient).expectNext(reauthorizedClient).verifyComplete();
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
@ -451,7 +482,11 @@ public class AuthorizedClientServiceReactiveOAuth2AuthorizedClientManagerTests {
this.authorizedClientManager.setContextAttributesMapper( this.authorizedClientManager.setContextAttributesMapper(
new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); new AuthorizedClientServiceReactiveOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); Mono<OAuth2AuthorizedClient> authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
StepVerifier.create(authorizedClient).expectNext(reauthorizedClient).verifyComplete(); // @formatter:off
StepVerifier.create(authorizedClient)
.expectNext(reauthorizedClient)
.verifyComplete();
// @formatter:on
verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal)); verify(this.authorizedClientService).saveAuthorizedClient(eq(reauthorizedClient), eq(this.principal));
this.saveAuthorizedClientProbe.assertWasSubscribed(); this.saveAuthorizedClientProbe.assertWasSubscribed();
verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any()); verify(this.authorizedClientService, never()).removeAuthorizedClient(any(), any());

View File

@ -65,41 +65,58 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.isInstanceOf(IllegalArgumentException.class).withMessage("accessTokenResponseClient cannot be null"); .isInstanceOf(IllegalArgumentException.class).withMessage("accessTokenResponseClient cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null))
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(this.principal).build(); .withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -107,8 +124,12 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests {
public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -125,8 +146,12 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken); this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -137,8 +162,12 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests {
public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes()); this.principal.getName(), TestOAuth2AccessTokens.noScopes());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -157,8 +186,12 @@ public class ClientCredentialsOAuth2AuthorizedClientProviderTests {
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());

View File

@ -66,41 +66,58 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.withMessage("accessTokenResponseClient cannot be null"); .withMessage("accessTokenResponseClient cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() { public void authorizeWhenNotClientCredentialsThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(this.principal).build(); .withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -108,8 +125,12 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() { public void authorizeWhenClientCredentialsAndNotAuthorizedThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -126,8 +147,12 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken); this.principal.getName(), accessToken);
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -138,8 +163,12 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() { public void authorizeWhenClientCredentialsAndTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes()); this.principal.getName(), TestOAuth2AccessTokens.noScopes());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -158,8 +187,12 @@ public class ClientCredentialsReactiveOAuth2AuthorizedClientProviderTests {
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext)
.block(); .block();
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);

View File

@ -67,9 +67,11 @@ public class InMemoryOAuth2AuthorizedClientServiceTests {
@Test @Test
public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentException() { public void constructorWhenAuthorizedClientsIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository, null)) .isThrownBy(() -> new InMemoryOAuth2AuthorizedClientService(this.clientRegistrationRepository, null))
.withMessage("authorizedClients cannot be empty"); .withMessage("authorizedClients cannot be empty");
// @formatter:on
} }
@Test @Test

View File

@ -59,13 +59,21 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(), OAuth2AccessToken accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "token", Instant.now(),
Instant.now().plus(Duration.ofDays(1))); Instant.now().plus(Duration.ofDays(1)));
// @formatter:off
private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).scope("read:user") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://github.com/login/oauth/authorize") .authorizationUri("https://github.com/login/oauth/authorize")
.tokenUri("https://github.com/login/oauth/access_token").userInfoUri("https://api.github.com/user") .tokenUri("https://github.com/login/oauth/access_token")
.userNameAttributeName("id").clientName("GitHub").clientId("clientId").clientSecret("clientSecret").build(); .userInfoUri("https://api.github.com/user")
.userNameAttributeName("id")
.clientName("GitHub")
.clientId("clientId")
.clientSecret("clientSecret")
.build();
// @formatter:on
@Before @Before
public void setup() { public void setup() {
@ -83,29 +91,37 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
@Test @Test
public void loadAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { public void loadAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() {
this.clientRegistrationId = null; this.clientRegistrationId = null;
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void loadAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() { public void loadAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() {
this.clientRegistrationId = ""; this.clientRegistrationId = "";
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void loadAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() { public void loadAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() {
this.principalName = null; this.principalName = null;
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void loadAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() { public void loadAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() {
this.principalName = ""; this.principalName = "";
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
@ -132,17 +148,23 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken); this.principalName, this.accessToken);
// @formatter:off
Mono<OAuth2AuthorizedClient> saveAndLoad = this.authorizedClientService Mono<OAuth2AuthorizedClient> saveAndLoad = this.authorizedClientService
.saveAuthorizedClient(authorizedClient, this.principal) .saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndLoad).expectNext(authorizedClient).verifyComplete(); StepVerifier.create(saveAndLoad)
.expectNext(authorizedClient)
.verifyComplete();
// @formatter:on
} }
@Test @Test
public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() { public void saveAuthorizedClientWhenAuthorizedClientNullThenIllegalArgumentException() {
OAuth2AuthorizedClient authorizedClient = null; OAuth2AuthorizedClient authorizedClient = null;
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)); .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal));
// @formatter:on
} }
@Test @Test
@ -150,36 +172,46 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken); this.principalName, this.accessToken);
this.principal = null; this.principal = null;
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)); .isThrownBy(() -> this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal));
// @formatter:on
} }
@Test @Test
public void removeAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() { public void removeAuthorizedClientWhenClientRegistrationIdNullThenIllegalArgumentException() {
this.clientRegistrationId = null; this.clientRegistrationId = null;
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void removeAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() { public void removeAuthorizedClientWhenClientRegistrationIdEmptyThenIllegalArgumentException() {
this.clientRegistrationId = ""; this.clientRegistrationId = "";
assertThatIllegalArgumentException().isThrownBy( // @formatter:off
() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void removeAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() { public void removeAuthorizedClientWhenPrincipalNameNullThenIllegalArgumentException() {
this.principalName = null; this.principalName = null;
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientService // @formatter:off
.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
public void removeAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() { public void removeAuthorizedClientWhenPrincipalNameEmptyThenIllegalArgumentException() {
this.principalName = ""; this.principalName = "";
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientService // @formatter:off
.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, this.principalName));
// @formatter:on
} }
@Test @Test
@ -188,10 +220,14 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
.willReturn(Mono.empty()); .willReturn(Mono.empty());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken); this.principalName, this.accessToken);
Mono<Void> saveAndDeleteAndLoad = this.authorizedClientService // @formatter:off
.saveAuthorizedClient(authorizedClient, this.principal).then(this.authorizedClientService Mono<Void> saveAndDeleteAndLoad = this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)
.removeAuthorizedClient(this.clientRegistrationId, this.principalName)); .then(this.authorizedClientService
StepVerifier.create(saveAndDeleteAndLoad).verifyComplete(); .removeAuthorizedClient(this.clientRegistrationId, this.principalName)
);
StepVerifier.create(saveAndDeleteAndLoad)
.verifyComplete();
// @formatter:on
} }
@Test @Test
@ -200,12 +236,14 @@ public class InMemoryReactiveOAuth2AuthorizedClientServiceTests {
.willReturn(Mono.just(this.clientRegistration)); .willReturn(Mono.just(this.clientRegistration));
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principalName, this.accessToken); this.principalName, this.accessToken);
Mono<OAuth2AuthorizedClient> saveAndDeleteAndLoad = this.authorizedClientService // @formatter:off
.saveAuthorizedClient(authorizedClient, this.principal) Mono<OAuth2AuthorizedClient> saveAndDeleteAndLoad = this.authorizedClientService.saveAuthorizedClient(authorizedClient, this.principal)
.then(this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId, .then(this.authorizedClientService.removeAuthorizedClient(this.clientRegistrationId,
this.principalName)) this.principalName))
.then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName)); .then(this.authorizedClientService.loadAuthorizedClient(this.clientRegistrationId, this.principalName));
StepVerifier.create(saveAndDeleteAndLoad).verifyComplete(); StepVerifier.create(saveAndDeleteAndLoad)
.verifyComplete();
// @formatter:on
} }
} }

View File

@ -104,9 +104,11 @@ public class JdbcOAuth2AuthorizedClientServiceTests {
@Test @Test
public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() { public void constructorWhenJdbcOperationsIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository)) .isThrownBy(() -> new JdbcOAuth2AuthorizedClientService(null, this.clientRegistrationRepository))
.withMessage("jdbcOperations cannot be null"); .withMessage("jdbcOperations cannot be null");
// @formatter:on
} }
@Test @Test
@ -118,31 +120,39 @@ public class JdbcOAuth2AuthorizedClientServiceTests {
@Test @Test
public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() { public void setAuthorizedClientRowMapperWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null)) .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientRowMapper(null))
.withMessage("authorizedClientRowMapper cannot be null"); .withMessage("authorizedClientRowMapper cannot be null");
// @formatter:on
} }
@Test @Test
public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() { public void setAuthorizedClientParametersMapperWhenNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null)) .isThrownBy(() -> this.authorizedClientService.setAuthorizedClientParametersMapper(null))
.withMessage("authorizedClientParametersMapper cannot be null"); .withMessage("authorizedClientParametersMapper cannot be null");
// @formatter:on
} }
@Test @Test
public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() { public void loadAuthorizedClientWhenClientRegistrationIdIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName")) .isThrownBy(() -> this.authorizedClientService.loadAuthorizedClient(null, "principalName"))
.withMessage("clientRegistrationId cannot be empty"); .withMessage("clientRegistrationId cannot be empty");
// @formatter:on
} }
@Test @Test
public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() { public void loadAuthorizedClientWhenPrincipalNameIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientService .isThrownBy(() -> this.authorizedClientService
.loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null)) .loadAuthorizedClient(this.clientRegistration.getRegistrationId(), null))
.withMessage("principalName cannot be empty"); .withMessage("principalName cannot be empty");
// @formatter:on
} }
@Test @Test
@ -351,8 +361,14 @@ public class JdbcOAuth2AuthorizedClientServiceTests {
} }
private static EmbeddedDatabase createDb(String schema) { private static EmbeddedDatabase createDb(String schema) {
return new EmbeddedDatabaseBuilder().generateUniqueName(true).setType(EmbeddedDatabaseType.HSQL) // @formatter:off
.setScriptEncoding("UTF-8").addScript(schema).build(); return new EmbeddedDatabaseBuilder()
.generateUniqueName(true)
.setType(EmbeddedDatabaseType.HSQL)
.setScriptEncoding("UTF-8")
.addScript(schema)
.build();
// @formatter:on
} }
private static Authentication createPrincipal() { private static Authentication createPrincipal() {

View File

@ -73,11 +73,16 @@ public class OAuth2AuthorizationContextTests {
@Test @Test
public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() { public void withAuthorizedClientWhenAllValuesProvidedThenAllValuesAreSet() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal).attributes((attributes) -> { .withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.attributes((attributes) -> {
attributes.put("attribute1", "value1"); attributes.put("attribute1", "value1");
attributes.put("attribute2", "value2"); attributes.put("attribute2", "value2");
}).build(); })
.build();
// @formatter:on
assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizationContext.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient); assertThat(authorizationContext.getAuthorizedClient()).isSameAs(this.authorizedClient);
assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal); assertThat(authorizationContext.getPrincipal()).isSameAs(this.principal);

View File

@ -89,11 +89,15 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@Test @Test
public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() {
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
.authorizationCode().build(); .authorizationCode()
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientRegistration().build()).principal(this.principal)
.build(); .build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientRegistration().build())
.principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext)); .isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext));
} }
@ -107,8 +111,12 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
TestClientRegistrations.clientRegistration().build(), this.principal.getName(), expiredAccessToken(), TestClientRegistrations.clientRegistration().build(), this.principal.getName(), expiredAccessToken(),
TestOAuth2RefreshTokens.refreshToken()); TestOAuth2RefreshTokens.refreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient).isNotNull(); assertThat(reauthorizedClient).isNotNull();
verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class));
@ -120,9 +128,12 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
.clientCredentials( .clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.build(); .build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientCredentials().build()).principal(this.principal) .withClientRegistration(TestClientRegistrations.clientCredentials().build())
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class));
@ -130,13 +141,17 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@Test @Test
public void buildWhenPasswordProviderThenProviderAuthorizes() { public void buildWhenPasswordProviderThenProviderAuthorizes() {
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
.password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient))
.build(); .build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.password().build()).principal(this.principal) .withClientRegistration(TestClientRegistrations.password().build())
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class)); verify(this.accessTokenClient).exchange(any(RequestEntity.class), eq(OAuth2AccessTokenResponse.class));
@ -154,8 +169,12 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
.build(); .build();
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
// authorization_code // authorization_code
// @formatter:off
OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(this.principal).build(); .withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext)); .isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext));
// refresh_token // refresh_token
@ -168,18 +187,25 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class), verify(this.accessTokenClient, times(1)).exchange(any(RequestEntity.class),
eq(OAuth2AccessTokenResponse.class)); eq(OAuth2AccessTokenResponse.class));
// client_credentials // client_credentials
// @formatter:off
OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientCredentials().build()).principal(this.principal) .withClientRegistration(TestClientRegistrations.clientCredentials().build())
.principal(this.principal)
.build(); .build();
// @formatter:on
authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext); authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class), verify(this.accessTokenClient, times(2)).exchange(any(RequestEntity.class),
eq(OAuth2AccessTokenResponse.class)); eq(OAuth2AccessTokenResponse.class));
// password // password
// @formatter:off
OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.password().build()).principal(this.principal) .withClientRegistration(TestClientRegistrations.password().build())
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
authorizedClient = authorizedClientProvider.authorize(passwordContext); authorizedClient = authorizedClientProvider.authorize(passwordContext);
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
verify(this.accessTokenClient, times(3)).exchange(any(RequestEntity.class), verify(this.accessTokenClient, times(3)).exchange(any(RequestEntity.class),
@ -189,11 +215,15 @@ public class OAuth2AuthorizedClientProviderBuilderTests {
@Test @Test
public void buildWhenCustomProviderThenProviderCalled() { public void buildWhenCustomProviderThenProviderCalled() {
OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class); OAuth2AuthorizedClientProvider customProvider = mock(OAuth2AuthorizedClientProvider.class);
// @formatter:off
OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder() OAuth2AuthorizedClientProvider authorizedClientProvider = OAuth2AuthorizedClientProviderBuilder.builder()
.provider(customProvider).build(); .provider(customProvider)
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientRegistration().build()).principal(this.principal)
.build(); .build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(TestClientRegistrations.clientRegistration().build())
.principal(this.principal)
.build();
// @formatter:on
authorizedClientProvider.authorize(authorizationContext); authorizedClientProvider.authorize(authorizationContext);
verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); verify(customProvider).authorize(any(OAuth2AuthorizationContext.class));
} }

View File

@ -72,52 +72,75 @@ public class PasswordOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null))
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotPasswordThenUnableToAuthorize() { public void authorizeWhenNotPasswordThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(this.principal).build(); .withClientRegistration(clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@Test @Test
public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null)
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@Test @Test
public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null).build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -125,10 +148,14 @@ public class PasswordOAuth2AuthorizedClientProviderTests {
public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -145,11 +172,14 @@ public class PasswordOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken); // without refresh token this.principal.getName(), accessToken); // without refresh token
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -166,11 +196,14 @@ public class PasswordOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with
// refresh // refresh
// token // token
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -190,11 +223,14 @@ public class PasswordOAuth2AuthorizedClientProviderTests {
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());

View File

@ -73,52 +73,75 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotPasswordThenUnableToAuthorize() { public void authorizeWhenNotPasswordThenUnableToAuthorize() {
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(clientRegistration).principal(this.principal).build(); .withClientRegistration(clientRegistration)
.principal(this.principal).
build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@Test @Test
public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedAndEmptyUsernameThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null) .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, null)
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@Test @Test
public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedAndEmptyPasswordThenUnableToAuthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null).build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, null)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -126,10 +149,14 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests {
public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() { public void authorizeWhenPasswordAndNotAuthorizedThenAuthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal) .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); OAuth2AuthorizedClient authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -146,11 +173,14 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken); // without refresh token this.principal.getName(), accessToken); // without refresh token
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block(); authorizedClient = this.authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(authorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(authorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -167,11 +197,14 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests {
this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with this.principal.getName(), accessToken, TestOAuth2RefreshTokens.refreshToken()); // with
// refresh // refresh
// token // token
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -191,11 +224,14 @@ public class PasswordReactiveOAuth2AuthorizedClientProviderTests {
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse().build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient) .withAuthorizedClient(authorizedClient)
.attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").principal(this.principal) .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.principal(this.principal)
.build(); .build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext) OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext)
.block(); .block();
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);

View File

@ -80,10 +80,16 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
@Test @Test
public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() { public void buildWhenAuthorizationCodeProviderThenProviderAuthorizes() {
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().authorizationCode().build(); .builder()
.authorizationCode()
.build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistrationBuilder.build()).principal(this.principal).build(); .withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block()); .isThrownBy(() -> authorizedClientProvider.authorize(authorizationContext).block());
} }
@ -93,12 +99,20 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().refreshToken().build(); .builder()
.refreshToken()
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(), OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistrationBuilder.build(),
this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken()); this.principal.getName(), expiredAccessToken(), TestOAuth2RefreshTokens.refreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); OAuth2AuthorizedClient reauthorizedClient = authorizedClientProvider.authorize(authorizationContext).block();
assertThat(reauthorizedClient).isNotNull(); assertThat(reauthorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1); assertThat(this.server.getRequestCount()).isEqualTo(1);
@ -112,12 +126,17 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().clientCredentials().build(); .builder()
.clientCredentials()
.build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistrationBuilder .withClientRegistration(this.clientRegistrationBuilder
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build())
.principal(this.principal).build(); .principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block();
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1); assertThat(this.server.getRequestCount()).isEqualTo(1);
@ -133,12 +152,17 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().password().build(); .builder().password().build();
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration( .withClientRegistration(
this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build())
.principal(this.principal).attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .principal(this.principal)
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext).block(); .attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
OAuth2AuthorizedClient authorizedClient = authorizedClientProvider.authorize(authorizationContext)
.block();
// @formatter:on
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(1); assertThat(this.server.getRequestCount()).isEqualTo(1);
RecordedRequest recordedRequest = this.server.takeRequest(); RecordedRequest recordedRequest = this.server.takeRequest();
@ -156,8 +180,12 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().authorizationCode().refreshToken().clientCredentials().password().build(); .builder().authorizationCode().refreshToken().clientCredentials().password().build();
// authorization_code // authorization_code
// @formatter:off
OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationCodeContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistrationBuilder.build()).principal(this.principal).build(); .withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationRequiredException.class) assertThatExceptionOfType(ClientAuthorizationRequiredException.class)
.isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block()); .isThrownBy(() -> authorizedClientProvider.authorize(authorizationCodeContext).block());
// refresh_token // refresh_token
@ -172,10 +200,13 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
String formParameters = recordedRequest.getBody().readUtf8(); String formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=refresh_token"); assertThat(formParameters).contains("grant_type=refresh_token");
// client_credentials // client_credentials
// @formatter:off
OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext OAuth2AuthorizationContext clientCredentialsContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistrationBuilder .withClientRegistration(this.clientRegistrationBuilder
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build()) .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).build())
.principal(this.principal).build(); .principal(this.principal)
.build();
// @formatter:on
authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext).block(); authorizedClient = authorizedClientProvider.authorize(clientCredentialsContext).block();
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(2); assertThat(this.server.getRequestCount()).isEqualTo(2);
@ -183,11 +214,15 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
formParameters = recordedRequest.getBody().readUtf8(); formParameters = recordedRequest.getBody().readUtf8();
assertThat(formParameters).contains("grant_type=client_credentials"); assertThat(formParameters).contains("grant_type=client_credentials");
// password // password
// @formatter:off
OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext OAuth2AuthorizationContext passwordContext = OAuth2AuthorizationContext
.withClientRegistration( .withClientRegistration(
this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build()) this.clientRegistrationBuilder.authorizationGrantType(AuthorizationGrantType.PASSWORD).build())
.principal(this.principal).attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username") .principal(this.principal)
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password").build(); .attribute(OAuth2AuthorizationContext.USERNAME_ATTRIBUTE_NAME, "username")
.attribute(OAuth2AuthorizationContext.PASSWORD_ATTRIBUTE_NAME, "password")
.build();
// @formatter:on
authorizedClient = authorizedClientProvider.authorize(passwordContext).block(); authorizedClient = authorizedClientProvider.authorize(passwordContext).block();
assertThat(authorizedClient).isNotNull(); assertThat(authorizedClient).isNotNull();
assertThat(this.server.getRequestCount()).isEqualTo(3); assertThat(this.server.getRequestCount()).isEqualTo(3);
@ -200,10 +235,16 @@ public class ReactiveOAuth2AuthorizedClientProviderBuilderTests {
public void buildWhenCustomProviderThenProviderCalled() { public void buildWhenCustomProviderThenProviderCalled() {
ReactiveOAuth2AuthorizedClientProvider customProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class); ReactiveOAuth2AuthorizedClientProvider customProvider = mock(ReactiveOAuth2AuthorizedClientProvider.class);
given(customProvider.authorize(any())).willReturn(Mono.empty()); given(customProvider.authorize(any())).willReturn(Mono.empty());
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().provider(customProvider).build(); .builder()
.provider(customProvider)
.build();
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistrationBuilder.build()).principal(this.principal).build(); .withClientRegistration(this.clientRegistrationBuilder.build())
.principal(this.principal)
.build();
// @formatter:on
authorizedClientProvider.authorize(authorizationContext).block(); authorizedClientProvider.authorize(authorizationContext).block();
verify(customProvider).authorize(any(OAuth2AuthorizationContext.class)); verify(customProvider).authorize(any(OAuth2AuthorizationContext.class));
} }

View File

@ -78,40 +78,57 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() { public void setAccessTokenResponseClientWhenClientIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null)) .isThrownBy(() -> this.authorizedClientProvider.setAccessTokenResponseClient(null))
.withMessage("accessTokenResponseClient cannot be null"); .withMessage("accessTokenResponseClient cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null))
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { public void authorizeWhenNotAuthorizedThenUnableToReauthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -119,8 +136,12 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), this.authorizedClient.getAccessToken()); this.principal.getName(), this.authorizedClient.getAccessToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -128,8 +149,12 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext)).isNull();
} }
@ -149,8 +174,12 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
// Shorten the lifespan of the access token by 90 seconds, which will ultimately // Shorten the lifespan of the access token by 90 seconds, which will ultimately
// force it to expire on the client // force it to expire on the client
this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90)); this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(90));
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -160,11 +189,19 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() { public void authorizeWhenAuthorizedAndAccessTokenExpiredThenReauthorize() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() // @formatter:off
.refreshToken("new-refresh-token").build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal).build(); .withAuthorizedClient(this.authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext); OAuth2AuthorizedClient reauthorizedClient = this.authorizedClientProvider.authorize(authorizationContext);
assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration); assertThat(reauthorizedClient.getClientRegistration()).isSameAs(this.clientRegistration);
assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName()); assertThat(reauthorizedClient.getPrincipalName()).isEqualTo(this.principal.getName());
@ -174,13 +211,21 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() { public void authorizeWhenAuthorizedAndRequestScopeProvidedThenScopeRequested() {
OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses.accessTokenResponse() // @formatter:off
.refreshToken("new-refresh-token").build(); OAuth2AccessTokenResponse accessTokenResponse = TestOAuth2AccessTokenResponses
.accessTokenResponse()
.refreshToken("new-refresh-token")
.build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
String[] requestScope = new String[] { "read", "write" }; String[] requestScope = new String[] { "read", "write" };
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal) .withAuthorizedClient(this.authorizedClient)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope).build(); .principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext); this.authorizedClientProvider.authorize(authorizationContext);
ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor = ArgumentCaptor ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor = ArgumentCaptor
.forClass(OAuth2RefreshTokenGrantRequest.class); .forClass(OAuth2RefreshTokenGrantRequest.class);
@ -192,9 +237,13 @@ public class RefreshTokenOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() {
String invalidRequestScope = "read write"; String invalidRequestScope = "read write";
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal) .withAuthorizedClient(this.authorizedClient)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope).build(); .principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope)
.build();
// @formatter:on
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext)) .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext))
.withMessageStartingWith("The context attribute must be of type String[] '" .withMessageStartingWith("The context attribute must be of type String[] '"

View File

@ -86,33 +86,48 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(null))
.withMessage("clockSkew cannot be null"); .withMessage("clockSkew cannot be null");
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1))) .isThrownBy(() -> this.authorizedClientProvider.setClockSkew(Duration.ofSeconds(-1)))
.withMessage("clockSkew must be >= 0"); .withMessage("clockSkew must be >= 0");
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.setClock(null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.setClock(null))
.withMessage("clock cannot be null"); .withMessage("clock cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() { public void authorizeWhenContextIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientProvider.authorize(null).block()) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(null).block())
.withMessage("context cannot be null"); .withMessage("context cannot be null");
// @formatter:on
} }
@Test @Test
public void authorizeWhenNotAuthorizedThenUnableToReauthorize() { public void authorizeWhenNotAuthorizedThenUnableToReauthorize() {
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withClientRegistration(this.clientRegistration).principal(this.principal).build(); .withClientRegistration(this.clientRegistration)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -120,8 +135,12 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() { public void authorizeWhenAuthorizedAndRefreshTokenIsNullThenUnableToReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), this.authorizedClient.getAccessToken()); this.principal.getName(), this.authorizedClient.getAccessToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -129,8 +148,12 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() { public void authorizeWhenAuthorizedAndAccessTokenNotExpiredThenNotReauthorize() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.clientRegistration,
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), this.authorizedClient.getRefreshToken());
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(authorizedClient).principal(this.principal).build(); .withAuthorizedClient(authorizedClient)
.principal(this.principal)
.build();
// @formatter:on
assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull(); assertThat(this.authorizedClientProvider.authorize(authorizationContext).block()).isNull();
} }
@ -181,9 +204,13 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
.refreshToken("new-refresh-token").build(); .refreshToken("new-refresh-token").build();
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
String[] requestScope = new String[] { "read", "write" }; String[] requestScope = new String[] { "read", "write" };
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal) .withAuthorizedClient(this.authorizedClient)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope).build(); .principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, requestScope)
.build();
// @formatter:on
this.authorizedClientProvider.authorize(authorizationContext).block(); this.authorizedClientProvider.authorize(authorizationContext).block();
ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor = ArgumentCaptor ArgumentCaptor<OAuth2RefreshTokenGrantRequest> refreshTokenGrantRequestArgCaptor = ArgumentCaptor
.forClass(OAuth2RefreshTokenGrantRequest.class); .forClass(OAuth2RefreshTokenGrantRequest.class);
@ -195,9 +222,13 @@ public class RefreshTokenReactiveOAuth2AuthorizedClientProviderTests {
@Test @Test
public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() { public void authorizeWhenAuthorizedAndInvalidRequestScopeProvidedThenThrowIllegalArgumentException() {
String invalidRequestScope = "read write"; String invalidRequestScope = "read write";
// @formatter:off
OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext OAuth2AuthorizationContext authorizationContext = OAuth2AuthorizationContext
.withAuthorizedClient(this.authorizedClient).principal(this.principal) .withAuthorizedClient(this.authorizedClient)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope).build(); .principal(this.principal)
.attribute(OAuth2AuthorizationContext.REQUEST_SCOPE_ATTRIBUTE_NAME, invalidRequestScope)
.build();
// @formatter:on
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block()) .isThrownBy(() -> this.authorizedClientProvider.authorize(authorizationContext).block())
.withMessageStartingWith("The context attribute must be of type String[] '" .withMessageStartingWith("The context attribute must be of type String[] '"

View File

@ -214,9 +214,15 @@ public class OAuth2LoginAuthenticationProviderTests {
Map<String, Object> additionalParameters = new HashMap<>(); Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2"); additionalParameters.put("param2", "value2");
return OAuth2AccessTokenResponse.withToken("access-token-1234").tokenType(OAuth2AccessToken.TokenType.BEARER) // @formatter:off
.expiresIn(expiresAt.getEpochSecond()).scopes(scopes).refreshToken("refresh-token-1234") return OAuth2AccessTokenResponse.withToken("access-token-1234")
.additionalParameters(additionalParameters).build(); .tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(expiresAt.getEpochSecond())
.scopes(scopes)
.refreshToken("refresh-token-1234")
.additionalParameters(additionalParameters)
.build();
// @formatter:on
} }
} }

View File

@ -124,7 +124,11 @@ public class OAuth2LoginReactiveAuthenticationManagerTests {
@Test @Test
public void authenticationWhenErrorThenOAuth2AuthenticationException() { public void authenticationWhenErrorThenOAuth2AuthenticationException() {
this.authorizationResponseBldr = OAuth2AuthorizationResponse.error("error").state("state"); // @formatter:off
this.authorizationResponseBldr = OAuth2AuthorizationResponse
.error("error")
.state("state");
// @formatter:on
assertThatExceptionOfType(OAuth2AuthenticationException.class) assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.manager.authenticate(loginToken()).block()); .isThrownBy(() -> this.manager.authenticate(loginToken()).block());
} }

View File

@ -60,12 +60,22 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString(); String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1").clientId("client-1") // @formatter:off
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) this.clientRegistration = ClientRegistration
.withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("https://client.com/callback/client-1").scope("read", "write") .redirectUri("https://client.com/callback/client-1")
.authorizationUri("https://provider.com/oauth2/authorize").tokenUri(tokenUri) .scope("read", "write")
.userInfoUri("https://provider.com/user").userNameAttributeName("id").clientName("client-1").build(); .authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri(tokenUri)
.userInfoUri("https://provider.com/user")
.userNameAttributeName("id")
.clientName("client-1")
.build();
// @formatter:on
} }
@After @After
@ -90,11 +100,17 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"read write\",\n" + " \"refresh_token\": \"refresh-token-1234\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"custom_parameter_2\": \"custom-value-2\"\n" + " \"token_type\": \"bearer\",\n"
+ "}\n"; + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
@ -121,8 +137,11 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()); this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest());
RecordedRequest recordedRequest = this.server.takeRequest(); RecordedRequest recordedRequest = this.server.takeRequest();
@ -131,8 +150,13 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.from(this.clientRegistration) ClientRegistration clientRegistration = this.from(this.clientRegistration)
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
@ -146,8 +170,13 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class) assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
@ -158,7 +187,11 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndMissingTokenTypeParameterThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\"\n" + "}\n"; // @formatter:off
String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class) assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
@ -169,9 +202,15 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n" + " \"scope\": \"read\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(this.authorizationCodeGrantRequest()); .getTokenResponse(this.authorizationCodeGrantRequest());
@ -180,9 +219,14 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() { public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"refresh_token\": \"refresh-token-1234\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"refresh_token\": \"refresh-token-1234\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
.getTokenResponse(this.authorizationCodeGrantRequest()); .getTokenResponse(this.authorizationCodeGrantRequest());
@ -201,12 +245,17 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"read write\",\n" + " \"refresh_token\": \"refresh-token-1234\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"token_type\": \"bearer\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"; + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n";
// "}\n"; // Make the JSON invalid/malformed // "}\n"; // Make the JSON invalid/malformed
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class) assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest())) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(this.authorizationCodeGrantRequest()))
@ -253,17 +302,21 @@ public class DefaultAuthorizationCodeTokenResponseClientTests {
} }
private ClientRegistration.Builder from(ClientRegistration registration) { private ClientRegistration.Builder from(ClientRegistration registration) {
// @formatter:off
return ClientRegistration.withRegistrationId(registration.getRegistrationId()) return ClientRegistration.withRegistrationId(registration.getRegistrationId())
.clientId(registration.getClientId()).clientSecret(registration.getClientSecret()) .clientId(registration.getClientId())
.clientSecret(registration.getClientSecret())
.clientAuthenticationMethod(registration.getClientAuthenticationMethod()) .clientAuthenticationMethod(registration.getClientAuthenticationMethod())
.authorizationGrantType(registration.getAuthorizationGrantType()) .authorizationGrantType(registration.getAuthorizationGrantType())
.redirectUri(registration.getRedirectUri()).scope(registration.getScopes()) .redirectUri(registration.getRedirectUri())
.scope(registration.getScopes())
.authorizationUri(registration.getProviderDetails().getAuthorizationUri()) .authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.tokenUri(registration.getProviderDetails().getTokenUri()) .tokenUri(registration.getProviderDetails().getTokenUri())
.userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri()) .userInfoUri(registration.getProviderDetails().getUserInfoEndpoint().getUri())
.userNameAttributeName( .userNameAttributeName(
registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName()) registration.getProviderDetails().getUserInfoEndpoint().getUserNameAttributeName())
.clientName(registration.getClientName()); .clientName(registration.getClientName());
// @formatter:on
} }
} }

View File

@ -57,10 +57,16 @@ public class DefaultClientCredentialsTokenResponseClientTests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString(); String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1").clientId("client-1") // @formatter:off
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).scope("read", "write") .clientId("client-1")
.tokenUri(tokenUri).build(); .clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri(tokenUri)
.build();
// @formatter:on
} }
@After @After
@ -70,12 +76,18 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() { public void setRequestEntityConverterWhenConverterIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null)); // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setRequestEntityConverter(null));
// @formatter:on
} }
@Test @Test
public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() { public void setRestOperationsWhenRestOperationsIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> this.tokenResponseClient.setRestOperations(null)); // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> this.tokenResponseClient.setRestOperations(null));
// @formatter:on
} }
@Test @Test
@ -85,10 +97,16 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"read write\",\n" + " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
@ -116,8 +134,13 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception { public void getTokenResponseWhenClientAuthenticationBasicThenAuthorizationHeaderIsSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -128,8 +151,13 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.from(this.clientRegistration) ClientRegistration clientRegistration = this.from(this.clientRegistration)
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
@ -145,8 +173,13 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -172,9 +205,14 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" + " \"scope\": \"read\"\n" String accessTokenSuccessResponse = "{\n"
+ "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -185,8 +223,13 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() { public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenAccessTokenHasDefaultScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -209,11 +252,16 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenMalformedResponseThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"read write\",\n" + " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"; + " \"token_type\": \"bearer\",\n"
// "}\n"; // Make the JSON invalid/malformed + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n";
// "}\n"; // Make the JSON invalid/malformed
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -225,7 +273,11 @@ public class DefaultClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; // @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration); this.clientRegistration);
@ -250,11 +302,15 @@ public class DefaultClientCredentialsTokenResponseClientTests {
} }
private ClientRegistration.Builder from(ClientRegistration registration) { private ClientRegistration.Builder from(ClientRegistration registration) {
// @formatter:off
return ClientRegistration.withRegistrationId(registration.getRegistrationId()) return ClientRegistration.withRegistrationId(registration.getRegistrationId())
.clientId(registration.getClientId()).clientSecret(registration.getClientSecret()) .clientId(registration.getClientId())
.clientSecret(registration.getClientSecret())
.clientAuthenticationMethod(registration.getClientAuthenticationMethod()) .clientAuthenticationMethod(registration.getClientAuthenticationMethod())
.authorizationGrantType(registration.getAuthorizationGrantType()).scope(registration.getScopes()) .authorizationGrantType(registration.getAuthorizationGrantType())
.scope(registration.getScopes())
.tokenUri(registration.getProviderDetails().getTokenUri()); .tokenUri(registration.getProviderDetails().getTokenUri());
// @formatter:on
} }
} }

View File

@ -88,8 +88,13 @@ public class DefaultPasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
@ -117,8 +122,13 @@ public class DefaultPasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
@ -134,8 +144,13 @@ public class DefaultPasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password); this.clientRegistrationBuilder.build(), this.username, this.password);
@ -148,9 +163,14 @@ public class DefaultPasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" + " \"scope\": \"read\"\n" String accessTokenSuccessResponse = "{\n"
+ "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password); this.clientRegistrationBuilder.build(), this.username, this.password);

View File

@ -92,8 +92,13 @@ public class DefaultRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
@ -137,8 +142,13 @@ public class DefaultRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
@ -151,9 +161,14 @@ public class DefaultRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" + " \"scope\": \"read\"\n" String accessTokenSuccessResponse = "{\n"
+ "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken,

View File

@ -75,11 +75,17 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\",\n" + " \"refresh_token\": \"refresh-token-1234\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"custom_parameter_2\": \"custom-value-2\"\n" + " \"token_type\": \"bearer\",\n"
+ "}\n"; + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"
+ "}\n";
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse)); .setBody(accessTokenSuccessResponse));
server.start(); server.start();
@ -127,11 +133,16 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(OAuth2AuthorizationException.class); this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("invalid_token_response")); this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\",\n" + " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"; + " \"token_type\": \"bearer\",\n"
// "}\n"; // Make the JSON invalid/malformed + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n";
// "}\n"; // Make the JSON invalid/malformed
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse)); .setBody(accessTokenSuccessResponse));
server.start(); server.start();
@ -160,7 +171,11 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(OAuth2AuthorizationException.class); this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("unauthorized_client")); this.exception.expectMessage(containsString("unauthorized_client"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; // @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setResponseCode(500).setBody(accessTokenErrorResponse)); .setResponseCode(500).setBody(accessTokenErrorResponse));
server.start(); server.start();
@ -200,8 +215,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
this.exception.expect(OAuth2AuthorizationException.class); this.exception.expect(OAuth2AuthorizationException.class);
this.exception.expectMessage(containsString("invalid_token_response")); this.exception.expectMessage(containsString("invalid_token_response"));
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse)); .setBody(accessTokenSuccessResponse));
server.start(); server.start();
@ -220,9 +240,14 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope()
throws Exception { throws Exception {
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse)); .setBody(accessTokenSuccessResponse));
server.start(); server.start();
@ -242,8 +267,13 @@ public class NimbusAuthorizationCodeTokenResponseClientTests {
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope()
throws Exception { throws Exception {
MockWebServer server = new MockWebServer(); MockWebServer server = new MockWebServer();
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) server.enqueue(new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(accessTokenSuccessResponse)); .setBody(accessTokenSuccessResponse));
server.start(); server.start();

View File

@ -48,21 +48,38 @@ public class OAuth2AuthorizationCodeGrantRequestEntityConverterTests {
private OAuth2AuthorizationCodeGrantRequestEntityConverter converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter(); private OAuth2AuthorizationCodeGrantRequestEntityConverter converter = new OAuth2AuthorizationCodeGrantRequestEntityConverter();
// @formatter:off
private ClientRegistration.Builder clientRegistrationBuilder = ClientRegistration private ClientRegistration.Builder clientRegistrationBuilder = ClientRegistration
.withRegistrationId("registration-1").clientId("client-1").clientSecret("secret") .withRegistrationId("registration-1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("https://client.com/callback/client-1").scope("read", "write") .redirectUri("https://client.com/callback/client-1")
.authorizationUri("https://provider.com/oauth2/authorize").tokenUri("https://provider.com/oauth2/token") .scope("read", "write")
.userInfoUri("https://provider.com/user").userNameAttributeName("id").clientName("client-1");
private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder = OAuth2AuthorizationRequest
.authorizationCode().clientId("client-1").state("state-1234")
.authorizationUri("https://provider.com/oauth2/authorize") .authorizationUri("https://provider.com/oauth2/authorize")
.redirectUri("https://client.com/callback/client-1").scopes(new HashSet(Arrays.asList("read", "write"))); .tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/user")
.userNameAttributeName("id")
.clientName("client-1");
// @formatter:on
// @formatter:off
private OAuth2AuthorizationRequest.Builder authorizationRequestBuilder = OAuth2AuthorizationRequest
.authorizationCode()
.clientId("client-1")
.state("state-1234")
.authorizationUri("https://provider.com/oauth2/authorize")
.redirectUri("https://client.com/callback/client-1")
.scopes(new HashSet(Arrays.asList("read", "write")));
// @formatter:on
// @formatter:off
private OAuth2AuthorizationResponse.Builder authorizationResponseBuilder = OAuth2AuthorizationResponse private OAuth2AuthorizationResponse.Builder authorizationResponseBuilder = OAuth2AuthorizationResponse
.success("code-1234").state("state-1234").redirectUri("https://client.com/callback/client-1"); .success("code-1234")
.state("state-1234")
.redirectUri("https://client.com/callback/client-1");
// @formatter:on
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test

View File

@ -44,11 +44,16 @@ public class OAuth2ClientCredentialsGrantRequestEntityConverterTests {
@Before @Before
public void setup() { public void setup() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1").clientSecret("secret") .clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).scope("read", "write") .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri("https://provider.com/oauth2/token").build(); .scope("read", "write")
.tokenUri("https://provider.com/oauth2/token")
.build();
// @formatter:on
this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); this.clientCredentialsGrantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
} }

View File

@ -37,10 +37,16 @@ public class OAuth2ClientCredentialsGrantRequestTests {
@Before @Before
public void setup() { public void setup() {
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1").clientId("client-1") // @formatter:off
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).scope("read", "write") .clientId("client-1")
.tokenUri("https://provider.com/oauth2/token").build(); .clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri("https://provider.com/oauth2/token")
.build();
// @formatter:on
} }
@Test @Test
@ -50,10 +56,15 @@ public class OAuth2ClientCredentialsGrantRequestTests {
@Test @Test
public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() { public void constructorWhenClientRegistrationInvalidGrantTypeThenThrowIllegalArgumentException() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1") ClientRegistration clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.clientId("client-1").authorizationGrantType(AuthorizationGrantType.IMPLICIT) .clientId("client-1")
.redirectUri("https://localhost:8080/redirect-uri").authorizationUri("https://provider.com/oauth2/auth") .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.clientName("Client 1").build(); .redirectUri("https://localhost:8080/redirect-uri")
.authorizationUri("https://provider.com/oauth2/auth")
.clientName("Client 1")
.build();
// @formatter:on
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration)).withMessage( .isThrownBy(() -> new OAuth2ClientCredentialsGrantRequest(clientRegistration)).withMessage(
"clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS"); "clientRegistration.authorizationGrantType must be AuthorizationGrantType.CLIENT_CREDENTIALS");

View File

@ -44,8 +44,12 @@ public class OAuth2PasswordGrantRequestEntityConverterTests {
@Before @Before
public void setup() { public void setup() {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.authorizationGrantType(AuthorizationGrantType.PASSWORD).scope("read", "write").build(); .authorizationGrantType(AuthorizationGrantType.PASSWORD)
.scope("read", "write")
.build();
// @formatter:on
this.passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password"); this.passwordGrantRequest = new OAuth2PasswordGrantRequest(clientRegistration, "user1", "password");
} }

View File

@ -64,7 +64,8 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
String tokenUri = this.server.url("/oauth2/token").toString(); String tokenUri = this.server.url("/oauth2/token").toString();
this.clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(tokenUri); this.clientRegistration = TestClientRegistrations.clientRegistration()
.tokenUri(tokenUri);
} }
@After @After
@ -74,11 +75,17 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\",\n" + " \"refresh_token\": \"refresh-token-1234\",\n" + " \"access_token\": \"access-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n" + " \"custom_parameter_2\": \"custom-value-2\"\n" + " \"token_type\": \"bearer\",\n"
+ "}\n"; + " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\",\n"
+ " \"refresh_token\": \"refresh-token-1234\",\n"
+ " \"custom_parameter_1\": \"custom-value-1\",\n"
+ " \"custom_parameter_2\": \"custom-value-2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
@ -198,8 +205,13 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ "\"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
assertThatExceptionOfType(OAuth2AuthorizationException.class) assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block()) .isThrownBy(() -> this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()).block())
@ -208,9 +220,14 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() { public void getTokenResponseWhenSuccessResponseIncludesScopeThenReturnAccessTokenResponseUsingResponseScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\"\n" + "}\n"; + "\"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.clientRegistration.scope("openid", "profile", "email", "address"); this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
@ -220,8 +237,13 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() { public void getTokenResponseWhenSuccessResponseDoesNotIncludeScopeThenReturnAccessTokenResponseUsingRequestedScope() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.clientRegistration.scope("openid", "profile", "email", "address"); this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient OAuth2AccessTokenResponse accessTokenResponse = this.tokenResponseClient
@ -257,9 +279,14 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
WebClient customClient = mock(WebClient.class); WebClient customClient = mock(WebClient.class);
given(customClient.post()).willReturn(WebClient.builder().build().post()); given(customClient.post()).willReturn(WebClient.builder().build().post());
this.tokenResponseClient.setWebClient(customClient); this.tokenResponseClient.setWebClient(customClient);
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenSuccessResponse = "{\n"
+ " \"scope\": \"openid profile\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"openid profile\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.clientRegistration.scope("openid", "profile", "email", "address"); this.clientRegistration.scope("openid", "profile", "email", "address");
OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest()) OAuth2AccessTokenResponse response = this.tokenResponseClient.getTokenResponse(authorizationCodeGrantRequest())
@ -270,8 +297,13 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier() public void getTokenResponseWhenOAuth2AuthorizationRequestContainsPkceParametersThenTokenRequestBodyShouldContainCodeVerifier()
throws Exception { throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block(); this.tokenResponseClient.getTokenResponse(pkceAuthorizationCodeGrantRequest()).block();
String body = this.server.takeRequest().getBody().readUtf8(); String body = this.server.takeRequest().getBody().readUtf8();
@ -287,13 +319,22 @@ public class WebClientReactiveAuthorizationCodeTokenResponseClientTests {
Map<String, Object> additionalParameters = new HashMap<>(); Map<String, Object> additionalParameters = new HashMap<>();
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234"); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE, "code-challenge-1234");
additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256"); additionalParameters.put(PkceParameterNames.CODE_CHALLENGE_METHOD, "S256");
// @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.clientId(registration.getClientId()).state("state") .clientId(registration.getClientId())
.state("state")
.authorizationUri(registration.getProviderDetails().getAuthorizationUri()) .authorizationUri(registration.getProviderDetails().getAuthorizationUri())
.redirectUri(registration.getRedirectUri()).scopes(registration.getScopes()).attributes(attributes) .redirectUri(registration.getRedirectUri())
.additionalParameters(additionalParameters).build(); .scopes(registration.getScopes())
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse.success("code").state("state") .attributes(attributes)
.redirectUri(registration.getRedirectUri()).build(); .additionalParameters(additionalParameters)
.build();
OAuth2AuthorizationResponse authorizationResponse = OAuth2AuthorizationResponse
.success("code")
.state("state")
.redirectUri(registration.getRedirectUri())
.build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse); authorizationResponse);
return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange); return new OAuth2AuthorizationCodeGrantRequest(registration, authorizationExchange);

View File

@ -68,9 +68,15 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenHeaderThenSuccess() throws Exception { public void getTokenResponseWhenHeaderThenSuccess() throws Exception {
enqueueJson("{\n" + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" // @formatter:off
+ " \"token_type\":\"bearer\",\n" + " \"expires_in\":3600,\n" enqueueJson("{\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + " \"scope\":\"create\"\n" + "}"); + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest( OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(
this.clientRegistration.build()); this.clientRegistration.build());
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
@ -86,9 +92,15 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
public void getTokenResponseWhenPostThenSuccess() throws Exception { public void getTokenResponseWhenPostThenSuccess() throws Exception {
ClientRegistration registration = this.clientRegistration ClientRegistration registration = this.clientRegistration
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
enqueueJson("{\n" + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" // @formatter:off
+ " \"token_type\":\"bearer\",\n" + " \"expires_in\":3600,\n" enqueueJson("{\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n" + " \"scope\":\"create\"\n" + "}"); + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\",\n"
+ " \"scope\":\"create\"\n"
+ "}");
// @formatter:on
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
RecordedRequest actualRequest = this.server.takeRequest(); RecordedRequest actualRequest = this.server.takeRequest();
@ -102,9 +114,14 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() { public void getTokenResponseWhenNoScopeThenClientRegistrationScopesDefaulted() {
ClientRegistration registration = this.clientRegistration.build(); ClientRegistration registration = this.clientRegistration.build();
enqueueJson("{\n" + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" // @formatter:off
+ " \"token_type\":\"bearer\",\n" + " \"expires_in\":3600,\n" enqueueJson("{\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + "}"); + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes()); assertThat(response.getAccessToken().getScopes()).isEqualTo(registration.getScopes());
@ -121,9 +138,14 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
given(customClient.post()).willReturn(WebClient.builder().build().post()); given(customClient.post()).willReturn(WebClient.builder().build().post());
this.client.setWebClient(customClient); this.client.setWebClient(customClient);
ClientRegistration registration = this.clientRegistration.build(); ClientRegistration registration = this.clientRegistration.build();
enqueueJson("{\n" + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n" // @formatter:off
+ " \"token_type\":\"bearer\",\n" + " \"expires_in\":3600,\n" enqueueJson("{\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n" + "}"); + " \"access_token\":\"MTQ0NjJkZmQ5OTM2NDE1ZTZjNGZmZjI3\",\n"
+ " \"token_type\":\"bearer\",\n"
+ " \"expires_in\":3600,\n"
+ " \"refresh_token\":\"IwOGYzYTlmM2YxOTQ5MGE3YmNmMDFkNTVk\"\n"
+ "}");
// @formatter:on
OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration); OAuth2ClientCredentialsGrantRequest request = new OAuth2ClientCredentialsGrantRequest(registration);
OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block(); OAuth2AccessTokenResponse response = this.client.getTokenResponse(request).block();
verify(customClient, atLeastOnce()).post(); verify(customClient, atLeastOnce()).post();
@ -142,8 +164,11 @@ public class WebClientReactiveClientCredentialsTokenResponseClientTests {
} }
private void enqueueUnexpectedResponse() { private void enqueueUnexpectedResponse() {
MockResponse response = new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) // @formatter:off
MockResponse response = new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setResponseCode(301); .setResponseCode(301);
// @formatter:on
this.server.enqueue(response); this.server.enqueue(response);
} }

View File

@ -81,8 +81,13 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
ClientRegistration clientRegistration = this.clientRegistrationBuilder.build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.build();
@ -111,8 +116,13 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
@ -128,8 +138,13 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password); this.clientRegistrationBuilder.build(), this.username, this.password);
@ -143,9 +158,14 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" + " \"scope\": \"read\"\n" String accessTokenSuccessResponse = "{\n"
+ "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password); this.clientRegistrationBuilder.build(), this.username, this.password);
@ -159,7 +179,11 @@ public class WebClientReactivePasswordTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; // @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest( OAuth2PasswordGrantRequest passwordGrantRequest = new OAuth2PasswordGrantRequest(
this.clientRegistrationBuilder.build(), this.username, this.password); this.clientRegistrationBuilder.build(), this.username, this.password);
@ -182,7 +206,11 @@ public class WebClientReactivePasswordTokenResponseClientTests {
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); // @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
} }
} }

View File

@ -87,8 +87,13 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception { public void getTokenResponseWhenSuccessResponseThenReturnAccessTokenResponse() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
Instant expiresAtBefore = Instant.now().plusSeconds(3600); Instant expiresAtBefore = Instant.now().plusSeconds(3600);
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
@ -115,8 +120,13 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception { public void getTokenResponseWhenClientAuthenticationPostThenFormParametersAreSent() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
ClientRegistration clientRegistration = this.clientRegistrationBuilder ClientRegistration clientRegistration = this.clientRegistrationBuilder
.clientAuthenticationMethod(ClientAuthenticationMethod.POST).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.POST).build();
@ -132,8 +142,13 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenSuccessResponseAndNotBearerTokenTypeThenThrowOAuth2AuthorizationException() {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"not-bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenSuccessResponse = "{\n"
+ " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"not-bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
@ -146,9 +161,14 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception { public void getTokenResponseWhenSuccessResponseIncludesScopeThenAccessTokenHasResponseScope() throws Exception {
String accessTokenSuccessResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" + " \"scope\": \"read\"\n" String accessTokenSuccessResponse = "{\n"
+ "}\n"; + " \"access_token\": \"access-token-1234\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenSuccessResponse)); this.server.enqueue(jsonResponse(accessTokenSuccessResponse));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken, this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken,
@ -163,7 +183,11 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
@Test @Test
public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() { public void getTokenResponseWhenErrorResponseThenThrowOAuth2AuthorizationException() {
String accessTokenErrorResponse = "{\n" + " \"error\": \"unauthorized_client\"\n" + "}\n"; // @formatter:off
String accessTokenErrorResponse = "{\n"
+ " \"error\": \"unauthorized_client\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400)); this.server.enqueue(jsonResponse(accessTokenErrorResponse).setResponseCode(400));
OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest( OAuth2RefreshTokenGrantRequest refreshTokenGrantRequest = new OAuth2RefreshTokenGrantRequest(
this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken); this.clientRegistrationBuilder.build(), this.accessToken, this.refreshToken);
@ -186,7 +210,11 @@ public class WebClientReactiveRefreshTokenTokenResponseClientTests {
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); // @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
} }
} }

View File

@ -36,8 +36,12 @@ public class OAuth2ErrorResponseErrorHandlerTests {
@Test @Test
public void handleErrorWhenErrorResponseBodyThenHandled() { public void handleErrorWhenErrorResponseBodyThenHandled() {
String errorResponse = "{\n" + " \"error\": \"unauthorized_client\",\n" // @formatter:off
+ " \"error_description\": \"The client is not authorized\"\n" + "}\n"; String errorResponse = "{\n"
+ " \"error\": \"unauthorized_client\",\n"
+ " \"error_description\": \"The client is not authorized\"\n"
+ "}\n";
// @formatter:on
MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST); MockClientHttpResponse response = new MockClientHttpResponse(errorResponse.getBytes(), HttpStatus.BAD_REQUEST);
assertThatExceptionOfType(OAuth2AuthorizationException.class) assertThatExceptionOfType(OAuth2AuthorizationException.class)
.isThrownBy(() -> this.errorHandler.handleError(response)) .isThrownBy(() -> this.errorHandler.handleError(response))

View File

@ -55,8 +55,11 @@ public class OAuth2AuthorizationRequestMixinTests {
Map<String, Object> additionalParameters = new LinkedHashMap<>(); Map<String, Object> additionalParameters = new LinkedHashMap<>();
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2"); additionalParameters.put("param2", "value2");
this.authorizationRequestBuilder = TestOAuth2AuthorizationRequests.request().scope("read", "write") // @formatter:off
this.authorizationRequestBuilder = TestOAuth2AuthorizationRequests.request()
.scope("read", "write")
.additionalParameters(additionalParameters); .additionalParameters(additionalParameters);
// @formatter:on
} }
@Test @Test
@ -69,8 +72,14 @@ public class OAuth2AuthorizationRequestMixinTests {
@Test @Test
public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception {
OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder.scopes(null).state(null) // @formatter:off
.additionalParameters(Map::clear).attributes(Map::clear).build(); OAuth2AuthorizationRequest authorizationRequest = this.authorizationRequestBuilder
.scopes(null)
.state(null)
.additionalParameters(Map::clear)
.attributes(Map::clear)
.build();
// @formatter:on
String expectedJson = asJson(authorizationRequest); String expectedJson = asJson(authorizationRequest);
String json = this.mapper.writeValueAsString(authorizationRequest); String json = this.mapper.writeValueAsString(authorizationRequest);
JSONAssert.assertEquals(expectedJson, json, true); JSONAssert.assertEquals(expectedJson, json, true);
@ -106,8 +115,13 @@ public class OAuth2AuthorizationRequestMixinTests {
@Test @Test
public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception {
// @formatter:off
OAuth2AuthorizationRequest expectedAuthorizationRequest = this.authorizationRequestBuilder.scopes(null) OAuth2AuthorizationRequest expectedAuthorizationRequest = this.authorizationRequestBuilder.scopes(null)
.state(null).additionalParameters(Map::clear).attributes(Map::clear).build(); .state(null)
.additionalParameters(Map::clear)
.attributes(Map::clear)
.build();
// @formatter:on
String json = asJson(expectedAuthorizationRequest); String json = asJson(expectedAuthorizationRequest);
OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class); OAuth2AuthorizationRequest authorizationRequest = this.mapper.readValue(json, OAuth2AuthorizationRequest.class);
assertThat(authorizationRequest.getAuthorizationUri()) assertThat(authorizationRequest.getAuthorizationUri())

View File

@ -67,8 +67,11 @@ public class OAuth2AuthorizedClientMixinTests {
Map<String, Object> providerConfigurationMetadata = new LinkedHashMap<>(); Map<String, Object> providerConfigurationMetadata = new LinkedHashMap<>();
providerConfigurationMetadata.put("config1", "value1"); providerConfigurationMetadata.put("config1", "value1");
providerConfigurationMetadata.put("config2", "value2"); providerConfigurationMetadata.put("config2", "value2");
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().scope("read", "write") // @formatter:off
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
.scope("read", "write")
.providerConfigurationMetadata(providerConfigurationMetadata); .providerConfigurationMetadata(providerConfigurationMetadata);
// @formatter:on
this.accessToken = TestOAuth2AccessTokens.scopes("read", "write"); this.accessToken = TestOAuth2AccessTokens.scopes("read", "write");
this.refreshToken = TestOAuth2RefreshTokens.refreshToken(); this.refreshToken = TestOAuth2RefreshTokens.refreshToken();
this.principalName = "principal-name"; this.principalName = "principal-name";
@ -85,8 +88,16 @@ public class OAuth2AuthorizedClientMixinTests {
@Test @Test
public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception { public void serializeWhenRequiredAttributesOnlyThenSerializes() throws Exception {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().clientSecret(null) // @formatter:off
.clientName(null).userInfoUri(null).userNameAttributeName(null).jwkSetUri(null).issuerUri(null).build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.clientSecret(null)
.clientName(null)
.userInfoUri(null)
.userNameAttributeName(null)
.jwkSetUri(null)
.issuerUri(null)
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, this.principalName, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(clientRegistration, this.principalName,
TestOAuth2AccessTokens.noScopes()); TestOAuth2AccessTokens.noScopes());
String expectedJson = asJson(authorizedClient); String expectedJson = asJson(authorizedClient);
@ -154,8 +165,16 @@ public class OAuth2AuthorizedClientMixinTests {
@Test @Test
public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception { public void deserializeWhenRequiredAttributesOnlyThenDeserializes() throws Exception {
ClientRegistration expectedClientRegistration = TestClientRegistrations.clientRegistration().clientSecret(null) // @formatter:off
.clientName(null).userInfoUri(null).userNameAttributeName(null).jwkSetUri(null).issuerUri(null).build(); ClientRegistration expectedClientRegistration = TestClientRegistrations.clientRegistration()
.clientSecret(null)
.clientName(null)
.userInfoUri(null)
.userNameAttributeName(null)
.jwkSetUri(null)
.issuerUri(null)
.build();
// @formatter:on
OAuth2AccessToken expectedAccessToken = TestOAuth2AccessTokens.noScopes(); OAuth2AccessToken expectedAccessToken = TestOAuth2AccessTokens.noScopes();
OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient(expectedClientRegistration, OAuth2AuthorizedClient expectedAuthorizedClient = new OAuth2AuthorizedClient(expectedClientRegistration,
this.principalName, expectedAccessToken); this.principalName, expectedAccessToken);

View File

@ -117,9 +117,15 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
} }
catch (NoSuchAlgorithmException ex) { catch (NoSuchAlgorithmException ex) {
} }
this.authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("openid", "profile", "email") // @formatter:off
.attributes(attributes).additionalParameters(additionalParameters).build(); this.authorizationRequest = TestOAuth2AuthorizationRequests.request()
this.authorizationResponse = TestOAuth2AuthorizationResponses.success().build(); .scope("openid", "profile", "email")
.attributes(attributes)
.additionalParameters(additionalParameters)
.build();
this.authorizationResponse = TestOAuth2AuthorizationResponses.success()
.build();
// @formatter:on
this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, this.authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class); this.accessTokenResponseClient = mock(OAuth2AccessTokenResponseClient.class);
@ -161,8 +167,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
@Test @Test
public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() { public void authenticateWhenAuthorizationRequestDoesNotContainOpenidScopeThenReturnNull() {
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request().scope("scope1") // @formatter:off
OAuth2AuthorizationRequest authorizationRequest = TestOAuth2AuthorizationRequests.request()
.scope("scope1")
.build(); .build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
this.authorizationResponse); this.authorizationResponse);
OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider OAuth2LoginAuthenticationToken authentication = (OAuth2LoginAuthenticationToken) this.authenticationProvider
@ -174,8 +183,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationErrorResponseThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE)); this.exception.expectMessage(containsString(OAuth2ErrorCodes.INVALID_SCOPE));
// @formatter:off
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error() OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.error()
.errorCode(OAuth2ErrorCodes.INVALID_SCOPE).build(); .errorCode(OAuth2ErrorCodes.INVALID_SCOPE)
.build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider this.authenticationProvider
@ -186,8 +198,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() { public void authenticateWhenAuthorizationResponseStateNotEqualAuthorizationRequestStateThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_state_parameter")); this.exception.expectMessage(containsString("invalid_state_parameter"));
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success().state("89012") // @formatter:off
OAuth2AuthorizationResponse authorizationResponse = TestOAuth2AuthorizationResponses.success()
.state("89012")
.build(); .build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(this.authorizationRequest,
authorizationResponse); authorizationResponse);
this.authenticationProvider this.authenticationProvider
@ -198,8 +213,12 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() { public void authenticateWhenTokenResponseDoesNotContainIdTokenThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_id_token")); this.exception.expectMessage(containsString("invalid_id_token"));
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withResponse(this.accessTokenSuccessResponse()).additionalParameters(Collections.emptyMap()).build(); .withResponse(this.accessTokenSuccessResponse())
.additionalParameters(Collections.emptyMap())
.build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(accessTokenResponse);
this.authenticationProvider this.authenticationProvider
.authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange)); .authenticate(new OAuth2LoginAuthenticationToken(this.clientRegistration, this.authorizationExchange));
@ -209,7 +228,11 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() { public void authenticateWhenJwkSetUriNotSetThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_signature_verifier")); this.exception.expectMessage(containsString("missing_signature_verifier"));
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().jwkSetUri(null).build(); // @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.jwkSetUri(null)
.build();
// @formatter:on
this.authenticationProvider this.authenticationProvider
.authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange)); .authenticate(new OAuth2LoginAuthenticationToken(clientRegistration, this.authorizationExchange));
} }
@ -323,9 +346,15 @@ public class OidcAuthorizationCodeAuthenticationProviderTests {
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2"); additionalParameters.put("param2", "value2");
additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token"); additionalParameters.put(OidcParameterNames.ID_TOKEN, "id-token");
return OAuth2AccessTokenResponse.withToken("access-token-1234").tokenType(OAuth2AccessToken.TokenType.BEARER) // @formatter:off
.expiresIn(expiresAt.getEpochSecond()).scopes(scopes).refreshToken("refresh-token-1234") return OAuth2AccessTokenResponse.withToken("access-token-1234")
.additionalParameters(additionalParameters).build(); .tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(expiresAt.getEpochSecond())
.scopes(scopes)
.refreshToken("refresh-token-1234")
.additionalParameters(additionalParameters)
.build();
// @formatter:on
} }
} }

View File

@ -90,10 +90,15 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Mock @Mock
private ReactiveJwtDecoder jwtDecoder; private ReactiveJwtDecoder jwtDecoder;
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration().scope("openid"); // @formatter:off
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration()
.scope("openid");
// @formatter:on
// @formatter:off
private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse.success("code") private OAuth2AuthorizationResponse.Builder authorizationResponseBldr = OAuth2AuthorizationResponse.success("code")
.state("state"); .state("state");
// @formatter:on
private OidcIdToken idToken = TestOidcIdTokens.idToken().build(); private OidcIdToken idToken = TestOidcIdTokens.idToken().build();
@ -153,7 +158,10 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticationWhenErrorThenOAuth2AuthenticationException() { public void authenticationWhenErrorThenOAuth2AuthenticationException() {
this.authorizationResponseBldr = OAuth2AuthorizationResponse.error("error").state("state"); // @formatter:off
this.authorizationResponseBldr = OAuth2AuthorizationResponse.error("error")
.state("state");
// @formatter:on
assertThatExceptionOfType(OAuth2AuthenticationException.class) assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.manager.authenticate(loginToken()).block()); .isThrownBy(() -> this.manager.authenticate(loginToken()).block());
} }
@ -167,10 +175,12 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticateWhenIdTokenValidationErrorThenOAuth2AuthenticationException() { public void authenticateWhenIdTokenValidationErrorThenOAuth2AuthenticationException() {
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters( .tokenType(OAuth2AccessToken.TokenType.BEARER)
Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.build(); .build();
// @formatter:on
given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse)); given(this.accessTokenResponseClient.getTokenResponse(any())).willReturn(Mono.just(accessTokenResponse));
given(this.jwtDecoder.decode(any())).willThrow(new JwtException("ID Token Validation Error")); given(this.jwtDecoder.decode(any())).willThrow(new JwtException("ID Token Validation Error"));
this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder); this.manager.setJwtDecoderFactory((c) -> this.jwtDecoder);
@ -181,10 +191,13 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticateWhenIdTokenInvalidNonceThenOAuth2AuthenticationException() { public void authenticateWhenIdTokenInvalidNonceThenOAuth2AuthenticationException() {
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters( .tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(
Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.build(); .build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -202,11 +215,13 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticationWhenOAuth2UserNotFoundThenEmpty() { public void authenticationWhenOAuth2UserNotFoundThenEmpty() {
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN, .additionalParameters(Collections.singletonMap(OidcParameterNames.ID_TOKEN,
"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.")) "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ."))
.build(); .build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -223,10 +238,14 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticationWhenOAuth2UserFoundThenSuccess() { public void authenticationWhenOAuth2UserFoundThenSuccess() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") // @formatter:off
.tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters( OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(
Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.build(); .build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -248,11 +267,15 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticationWhenRefreshTokenThenRefreshTokenInAuthorizedClient() { public void authenticationWhenRefreshTokenThenRefreshTokenInAuthorizedClient() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") // @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER) .tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters( .additionalParameters(
Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.refreshToken("refresh-token").build(); .refreshToken("refresh-token")
.build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -281,8 +304,13 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()); additionalParameters.put(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue());
additionalParameters.put("param1", "value1"); additionalParameters.put("param1", "value1");
additionalParameters.put("param2", "value2"); additionalParameters.put("param2", "value2");
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") // @formatter:off
.tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters(additionalParameters).build(); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(additionalParameters)
.build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -304,10 +332,13 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
@Test @Test
public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() { public void authenticateWhenAuthoritiesMapperSetThenReturnMappedAuthorities() {
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
// @formatter:off
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo") OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("foo")
.tokenType(OAuth2AccessToken.TokenType.BEARER).additionalParameters( .tokenType(OAuth2AccessToken.TokenType.BEARER)
.additionalParameters(
Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue())) Collections.singletonMap(OidcParameterNames.ID_TOKEN, this.idToken.getTokenValue()))
.build(); .build();
// @formatter:on
OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken(); OAuth2AuthorizationCodeAuthenticationToken authorizationCodeAuthentication = loginToken();
Map<String, Object> claims = new HashMap<>(); Map<String, Object> claims = new HashMap<>();
claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com"); claims.put(IdTokenClaimNames.ISS, "https://issuer.example.com");
@ -342,13 +373,19 @@ public class OidcAuthorizationCodeReactiveAuthenticationManagerTests {
} }
catch (NoSuchAlgorithmException ex) { catch (NoSuchAlgorithmException ex) {
} }
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode().state("state") // @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.state("state")
.clientId(clientRegistration.getClientId()) .clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri())
.redirectUri(clientRegistration.getRedirectUri()).scopes(clientRegistration.getScopes()) .redirectUri(clientRegistration.getRedirectUri())
.additionalParameters(additionalParameters).attributes(attributes).build(); .scopes(clientRegistration.getScopes())
.additionalParameters(additionalParameters)
.attributes(attributes).build();
OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr
.redirectUri(clientRegistration.getRedirectUri()).build(); .redirectUri(clientRegistration.getRedirectUri())
.build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse); authorizationResponse);
return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange); return new OAuth2AuthorizationCodeAuthenticationToken(clientRegistration, authorizationExchange);

View File

@ -50,7 +50,11 @@ import static org.mockito.Mockito.verify;
*/ */
public class OidcIdTokenDecoderFactoryTests { public class OidcIdTokenDecoderFactoryTests {
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration().scope("openid"); // @formatter:off
private ClientRegistration.Builder registration = TestClientRegistrations
.clientRegistration()
.scope("openid");
// @formatter:on
private OidcIdTokenDecoderFactory idTokenDecoderFactory; private OidcIdTokenDecoderFactory idTokenDecoderFactory;

View File

@ -72,26 +72,39 @@ public class OidcIdTokenValidatorTests {
@Test @Test
public void setClockSkewWhenNullThenThrowIllegalArgumentException() { public void setClockSkewWhenNullThenThrowIllegalArgumentException() {
OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build());
assertThatIllegalArgumentException().isThrownBy(() -> idTokenValidator.setClockSkew(null)); // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> idTokenValidator.setClockSkew(null));
// @formatter:on
} }
@Test @Test
public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() { public void setClockSkewWhenNegativeSecondsThenThrowIllegalArgumentException() {
OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build());
assertThatIllegalArgumentException().isThrownBy(() -> idTokenValidator.setClockSkew(Duration.ofSeconds(-1))); // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> idTokenValidator.setClockSkew(Duration.ofSeconds(-1)));
// @formatter:on
} }
@Test @Test
public void setClockWhenNullThenThrowIllegalArgumentException() { public void setClockWhenNullThenThrowIllegalArgumentException() {
OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build()); OidcIdTokenValidator idTokenValidator = new OidcIdTokenValidator(this.registration.build());
assertThatIllegalArgumentException().isThrownBy(() -> idTokenValidator.setClock(null)); // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> idTokenValidator.setClock(null));
// @formatter:on
} }
@Test @Test
public void validateWhenIssuerNullThenHasErrors() { public void validateWhenIssuerNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.ISS); this.claims.remove(IdTokenClaimNames.ISS);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS));
// @formatter:on
} }
@Test @Test
@ -101,8 +114,12 @@ public class OidcIdTokenValidatorTests {
* issuer in the ID Token, the validation must fail * issuer in the ID Token, the validation must fail
*/ */
this.registration = this.registration.issuerUri("https://somethingelse.com"); this.registration = this.registration.issuerUri("https://somethingelse.com");
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.ISS));
// @formatter:on
} }
@Test @Test
@ -118,22 +135,34 @@ public class OidcIdTokenValidatorTests {
@Test @Test
public void validateWhenSubNullThenHasErrors() { public void validateWhenSubNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.SUB); this.claims.remove(IdTokenClaimNames.SUB);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB));
// @formatter:on
} }
@Test @Test
public void validateWhenAudNullThenHasErrors() { public void validateWhenAudNullThenHasErrors() {
this.claims.remove(IdTokenClaimNames.AUD); this.claims.remove(IdTokenClaimNames.AUD);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD));
// @formatter:on
} }
@Test @Test
public void validateWhenIssuedAtNullThenHasErrors() { public void validateWhenIssuedAtNullThenHasErrors() {
this.issuedAt = null; this.issuedAt = null;
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT));
// @formatter:on
} }
@Test @Test
@ -146,15 +175,23 @@ public class OidcIdTokenValidatorTests {
@Test @Test
public void validateWhenAudMultipleAndAzpNullThenHasErrors() { public void validateWhenAudMultipleAndAzpNullThenHasErrors() {
this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other")); this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id", "other"));
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP));
// @formatter:on
} }
@Test @Test
public void validateWhenAzpNotClientIdThenHasErrors() { public void validateWhenAzpNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AZP, "other"); this.claims.put(IdTokenClaimNames.AZP, "other");
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP));
// @formatter:on
} }
@Test @Test
@ -168,15 +205,23 @@ public class OidcIdTokenValidatorTests {
public void validateWhenMultipleAudAzpNotClientIdThenHasErrors() { public void validateWhenMultipleAudAzpNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id-1", "client-id-2")); this.claims.put(IdTokenClaimNames.AUD, Arrays.asList("client-id-1", "client-id-2"));
this.claims.put(IdTokenClaimNames.AZP, "other-client"); this.claims.put(IdTokenClaimNames.AZP, "other-client");
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.AZP));
// @formatter:on
} }
@Test @Test
public void validateWhenAudNotClientIdThenHasErrors() { public void validateWhenAudNotClientIdThenHasErrors() {
this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client")); this.claims.put(IdTokenClaimNames.AUD, Collections.singletonList("other-client"));
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD));
// @formatter:on
} }
@Test @Test
@ -192,8 +237,12 @@ public class OidcIdTokenValidatorTests {
this.issuedAt = Instant.now().minus(Duration.ofSeconds(60)); this.issuedAt = Instant.now().minus(Duration.ofSeconds(60));
this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(30)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(30));
this.clockSkew = Duration.ofSeconds(0); this.clockSkew = Duration.ofSeconds(0);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP));
// @formatter:on
} }
@Test @Test
@ -209,8 +258,12 @@ public class OidcIdTokenValidatorTests {
this.issuedAt = Instant.now().plus(Duration.ofMinutes(1)); this.issuedAt = Instant.now().plus(Duration.ofMinutes(1));
this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(60)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(60));
this.clockSkew = Duration.ofMinutes(0); this.clockSkew = Duration.ofMinutes(0);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT));
// @formatter:on
} }
@Test @Test
@ -218,8 +271,12 @@ public class OidcIdTokenValidatorTests {
this.issuedAt = Instant.now().minus(Duration.ofSeconds(10)); this.issuedAt = Instant.now().minus(Duration.ofSeconds(10));
this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(5)); this.expiresAt = this.issuedAt.plus(Duration.ofSeconds(5));
this.clockSkew = Duration.ofSeconds(0); this.clockSkew = Duration.ofSeconds(0);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP));
// @formatter:on
} }
@Test @Test
@ -228,24 +285,38 @@ public class OidcIdTokenValidatorTests {
this.claims.remove(IdTokenClaimNames.AUD); this.claims.remove(IdTokenClaimNames.AUD);
this.issuedAt = null; this.issuedAt = null;
this.expiresAt = null; this.expiresAt = null;
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB)) .allMatch((msg) -> msg.contains(IdTokenClaimNames.SUB))
.allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD)) .allMatch((msg) -> msg.contains(IdTokenClaimNames.AUD))
.allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT)) .allMatch((msg) -> msg.contains(IdTokenClaimNames.IAT))
.allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP)); .allMatch((msg) -> msg.contains(IdTokenClaimNames.EXP));
// @formatter:on
} }
@Test @Test
public void validateFormatError() { public void validateFormatError() {
this.claims.remove(IdTokenClaimNames.SUB); this.claims.remove(IdTokenClaimNames.SUB);
this.claims.remove(IdTokenClaimNames.AUD); this.claims.remove(IdTokenClaimNames.AUD);
assertThat(this.validateIdToken()).hasSize(1).extracting(OAuth2Error::getDescription) // @formatter:off
assertThat(this.validateIdToken())
.hasSize(1)
.extracting(OAuth2Error::getDescription)
.allMatch((msg) -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}")); .allMatch((msg) -> msg.equals("The ID Token contains invalid claims: {sub=null, aud=null}"));
// @formatter:on
} }
private Collection<OAuth2Error> validateIdToken() { private Collection<OAuth2Error> validateIdToken() {
Jwt idToken = Jwt.withTokenValue("token").issuedAt(this.issuedAt).expiresAt(this.expiresAt) // @formatter:off
.headers((h) -> h.putAll(this.headers)).claims((c) -> c.putAll(this.claims)).build(); Jwt idToken = Jwt.withTokenValue("token")
.issuedAt(this.issuedAt)
.expiresAt(this.expiresAt)
.headers((h) -> h.putAll(this.headers))
.claims((c) -> c.putAll(this.claims))
.build();
// @formatter:on
OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build()); OidcIdTokenValidator validator = new OidcIdTokenValidator(this.registration.build());
validator.setClockSkew(this.clockSkew); validator.setClockSkew(this.clockSkew);
return validator.validate(idToken).getErrors(); return validator.validate(idToken).getErrors();

View File

@ -50,7 +50,10 @@ import static org.mockito.Mockito.verify;
*/ */
public class ReactiveOidcIdTokenDecoderFactoryTests { public class ReactiveOidcIdTokenDecoderFactoryTests {
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration().scope("openid"); // @formatter:off
private ClientRegistration.Builder registration = TestClientRegistrations.clientRegistration()
.scope("openid");
// @formatter:on
private ReactiveOidcIdTokenDecoderFactory idTokenDecoderFactory; private ReactiveOidcIdTokenDecoderFactory idTokenDecoderFactory;

View File

@ -157,9 +157,16 @@ public class OidcUserServiceTests {
// gh-6886 // gh-6886
@Test @Test
public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesMatchThenUserInfoEndpointRequested() { public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesMatchThenUserInfoEndpointRequested() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -173,9 +180,16 @@ public class OidcUserServiceTests {
// gh-6886 // gh-6886
@Test @Test
public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesEmptyThenUserInfoEndpointRequested() { public void loadUserWhenNonStandardScopesAuthorizedAndAccessibleScopesEmptyThenUserInfoEndpointRequested() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -189,9 +203,16 @@ public class OidcUserServiceTests {
// gh-6886 // gh-6886
@Test @Test
public void loadUserWhenStandardScopesAuthorizedThenUserInfoEndpointRequested() { public void loadUserWhenStandardScopesAuthorizedThenUserInfoEndpointRequested() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -202,9 +223,16 @@ public class OidcUserServiceTests {
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -234,8 +262,12 @@ public class OidcUserServiceTests {
public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseAndUserInfoSubjectIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("invalid_user_info_response")); this.exception.expectMessage(containsString("invalid_user_info_response"));
String userInfoResponse = "{\n" + " \"email\": \"full_name@provider.com\",\n" + " \"name\": \"full name\"\n" // @formatter:off
String userInfoResponse = "{\n"
+ " \"email\": \"full_name@provider.com\",\n"
+ " \"name\": \"full name\"\n"
+ "}\n"; + "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -259,10 +291,16 @@ public class OidcUserServiceTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString( this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n"; + " \"sub\": \"subject1\",\n"
// "}\n"; // Make the JSON invalid/malformed + " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -292,9 +330,16 @@ public class OidcUserServiceTests {
@Test @Test
public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() { public void loadUserWhenCustomUserNameAttributeNameThenGetNameReturnsCustomUserName() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -307,9 +352,16 @@ public class OidcUserServiceTests {
// gh-5294 // gh-5294
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -321,9 +373,16 @@ public class OidcUserServiceTests {
// gh-5500 // gh-5500
@Test @Test
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -338,9 +397,16 @@ public class OidcUserServiceTests {
// gh-5500 // gh-5500
@Test @Test
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -355,9 +421,16 @@ public class OidcUserServiceTests {
@Test @Test
public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() { public void loadUserWhenCustomClaimTypeConverterFactorySetThenApplied() {
String userInfoResponse = "{\n" + " \"sub\": \"subject1\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"given_name\": \"first\",\n" + " \"family_name\": \"last\",\n" String userInfoResponse = "{\n"
+ " \"preferred_username\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"sub\": \"subject1\",\n"
+ " \"name\": \"first last\",\n"
+ " \"given_name\": \"first\",\n"
+ " \"family_name\": \"last\",\n"
+ " \"preferred_username\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -395,7 +468,11 @@ public class OidcUserServiceTests {
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); // @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
} }
} }

View File

@ -49,9 +49,12 @@ import static org.mockito.Mockito.mock;
@RunWith(MockitoJUnitRunner.class) @RunWith(MockitoJUnitRunner.class)
public class OidcClientInitiatedLogoutSuccessHandlerTests { public class OidcClientInitiatedLogoutSuccessHandlerTests {
ClientRegistration registration = TestClientRegistrations.clientRegistration() // @formatter:off
ClientRegistration registration = TestClientRegistrations
.clientRegistration()
.providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint")) .providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint"))
.build(); .build();
// @formatter:on
ClientRegistrationRepository repository = new InMemoryClientRegistrationRepository(this.registration); ClientRegistrationRepository repository = new InMemoryClientRegistrationRepository(this.registration);

View File

@ -51,9 +51,12 @@ import static org.mockito.Mockito.mock;
*/ */
public class OidcClientInitiatedServerLogoutSuccessHandlerTests { public class OidcClientInitiatedServerLogoutSuccessHandlerTests {
ClientRegistration registration = TestClientRegistrations.clientRegistration() // @formatter:off
ClientRegistration registration = TestClientRegistrations
.clientRegistration()
.providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint")) .providerConfigurationMetadata(Collections.singletonMap("end_session_endpoint", "https://endpoint"))
.build(); .build();
// @formatter:on
ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository( ReactiveClientRegistrationRepository repository = new InMemoryReactiveClientRegistrationRepository(
this.registration); this.registration);

View File

@ -72,21 +72,42 @@ public class ClientRegistrationTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationGrantTypeIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationGrantTypeIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC).authorizationGrantType(null) ClientRegistration.withRegistrationId(REGISTRATION_ID)
.redirectUri(REDIRECT_URI).scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID)
.tokenUri(TOKEN_URI).userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI) .clientSecret(CLIENT_SECRET)
.clientName(CLIENT_NAME).build(); .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(null)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() { public void buildWhenAuthorizationCodeGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .clientId(CLIENT_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .clientSecret(CLIENT_SECRET)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).issuerUri(ISSUER_URI) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.providerConfigurationMetadata(PROVIDER_CONFIGURATION_METADATA).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.issuerUri(ISSUER_URI)
.providerConfigurationMetadata(PROVIDER_CONFIGURATION_METADATA)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
@ -107,172 +128,308 @@ public class ClientRegistrationTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationCodeGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(null).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(null)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationCodeGrantClientIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(null).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(null)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenDefaultToEmpty() { public void buildWhenAuthorizationCodeGrantClientSecretIsNullThenDefaultToEmpty() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(null).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .clientSecret(null)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
assertThat(clientRegistration.getClientSecret()).isEqualTo(""); assertThat(clientRegistration.getClientSecret()).isEqualTo("");
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .clientSecret(CLIENT_SECRET)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretNullThenDefaultToNone() { public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretNullThenDefaultToNone() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(null) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .clientSecret(null)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE);
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretBlankThenDefaultToNone() { public void buildWhenAuthorizationCodeGrantClientAuthenticationMethodNotProvidedAndClientSecretBlankThenDefaultToNone() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(" ").authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .clientId(CLIENT_ID)
.redirectUri(REDIRECT_URI).scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientSecret(" ")
.tokenUri(TOKEN_URI).userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.clientName(CLIENT_NAME).build(); .redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.NONE);
assertThat(clientRegistration.getClientSecret()).isEqualTo(""); assertThat(clientRegistration.getClientSecret()).isEqualTo("");
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantRedirectUriIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationCodeGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(null) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(null)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
// gh-5494 // gh-5494
@Test @Test
public void buildWhenAuthorizationCodeGrantScopeIsNullThenScopeNotRequired() { public void buildWhenAuthorizationCodeGrantScopeIsNullThenScopeNotRequired() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope((String[]) null).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope((String[]) null)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationCodeGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(null).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope(SCOPES.toArray(new String[0]))
.authorizationUri(null)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenAuthorizationCodeGrantTokenUriIsNullThenThrowIllegalArgumentException() { public void buildWhenAuthorizationCodeGrantTokenUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(null) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(null)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantClientNameNotProvidedThenDefaultToRegistrationId() { public void buildWhenAuthorizationCodeGrantClientNameNotProvidedThenDefaultToRegistrationId() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).jwkSetUri(JWK_SET_URI).build(); .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.jwkSetUri(JWK_SET_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId()); assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId());
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantScopeDoesNotContainOpenidThenJwkSetUriNotRequired() { public void buildWhenAuthorizationCodeGrantScopeDoesNotContainOpenidThenJwkSetUriNotRequired() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("scope1").authorizationUri(AUTHORIZATION_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).tokenUri(TOKEN_URI).clientName(CLIENT_NAME) .scope("scope1")
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
} }
// gh-5494 // gh-5494
@Test @Test
public void buildWhenAuthorizationCodeGrantScopeIsNullThenJwkSetUriNotRequired() { public void buildWhenAuthorizationCodeGrantScopeIsNullThenJwkSetUriNotRequired() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) // @formatter:off
ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI).clientName(CLIENT_NAME).build(); .redirectUri(REDIRECT_URI)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataIsNullThenDefaultToEmpty() { public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataIsNullThenDefaultToEmpty() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER).providerConfigurationMetadata(null) .scope(SCOPES.toArray(new String[0]))
.jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME).build(); .authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.providerConfigurationMetadata(null)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull();
assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty();
} }
@Test @Test
public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataEmptyThenIsEmpty() { public void buildWhenAuthorizationCodeGrantProviderConfigurationMetadataEmptyThenIsEmpty() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.HEADER) .userInfoAuthenticationMethod(AuthenticationMethod.HEADER)
.providerConfigurationMetadata(Collections.emptyMap()).jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME) .providerConfigurationMetadata(Collections.emptyMap())
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build(); .build();
// @formatter:on
assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isNotNull();
assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty(); assertThat(clientRegistration.getProviderDetails().getConfigurationMetadata()).isEmpty();
} }
@Test @Test
public void buildWhenImplicitGrantAllAttributesProvidedThenAllAttributesAreSet() { public void buildWhenImplicitGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT); assertThat(registration.getAuthorizationGrantType()).isEqualTo(AuthorizationGrantType.IMPLICIT);
@ -286,72 +443,129 @@ public class ClientRegistrationTests {
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { public void buildWhenImplicitGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(null).clientId(CLIENT_ID) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) ClientRegistration.withRegistrationId(null)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantClientIdIsNullThenThrowIllegalArgumentException() { public void buildWhenImplicitGrantClientIdIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(null) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) ClientRegistration.withRegistrationId(REGISTRATION_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientId(null)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantRedirectUriIsNullThenThrowIllegalArgumentException() { public void buildWhenImplicitGrantRedirectUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(null) ClientRegistration.withRegistrationId(REGISTRATION_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(null)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
// gh-5494 // gh-5494
@Test @Test
public void buildWhenImplicitGrantScopeIsNullThenScopeNotRequired() { public void buildWhenImplicitGrantScopeIsNullThenScopeNotRequired() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) ClientRegistration.withRegistrationId(REGISTRATION_ID)
.scope((String[]) null).authorizationUri(AUTHORIZATION_URI) .clientId(CLIENT_ID)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope((String[]) null)
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test(expected = IllegalArgumentException.class) @Test(expected = IllegalArgumentException.class)
public void buildWhenImplicitGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() { public void buildWhenImplicitGrantAuthorizationUriIsNullThenThrowIllegalArgumentException() {
ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) ClientRegistration.withRegistrationId(REGISTRATION_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(null) .clientId(CLIENT_ID)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).clientName(CLIENT_NAME).build(); .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(null)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
} }
@Test @Test
public void buildWhenImplicitGrantClientNameNotProvidedThenDefaultToRegistrationId() { public void buildWhenImplicitGrantClientNameNotProvidedThenDefaultToRegistrationId() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).authorizationGrantType(AuthorizationGrantType.IMPLICIT).redirectUri(REDIRECT_URI) .clientId(CLIENT_ID)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI) .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM).build(); .redirectUri(REDIRECT_URI)
.scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.userInfoAuthenticationMethod(AuthenticationMethod.FORM)
.build();
// @formatter:on
assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId()); assertThat(clientRegistration.getClientName()).isEqualTo(clientRegistration.getRegistrationId());
} }
@Test @Test
public void buildWhenOverrideRegistrationIdThenOverridden() { public void buildWhenOverrideRegistrationIdThenOverridden() {
String overriddenId = "override"; String overriddenId = "override";
// @formatter:off
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.registrationId(overriddenId).clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .registrationId(overriddenId)
.clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope(SCOPES.toArray(new String[0])).authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI) .redirectUri(REDIRECT_URI)
.jwkSetUri(JWK_SET_URI).clientName(CLIENT_NAME).build(); .scope(SCOPES.toArray(new String[0]))
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.jwkSetUri(JWK_SET_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(overriddenId); assertThat(registration.getRegistrationId()).isEqualTo(overriddenId);
} }
@Test @Test
public void buildWhenClientCredentialsGrantAllAttributesProvidedThenAllAttributesAreSet() { public void buildWhenClientCredentialsGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).scope(SCOPES.toArray(new String[0])) .clientId(CLIENT_ID)
.tokenUri(TOKEN_URI).clientName(CLIENT_NAME).build(); .clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope(SCOPES.toArray(new String[0]))
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
@ -379,17 +593,28 @@ public class ClientRegistrationTests {
@Test @Test
public void buildWhenClientCredentialsGrantClientSecretIsNullThenDefaultToEmpty() { public void buildWhenClientCredentialsGrantClientSecretIsNullThenDefaultToEmpty() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(null).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).tokenUri(TOKEN_URI).build(); .clientSecret(null)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(TOKEN_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientSecret()).isEqualTo(""); assertThat(clientRegistration.getClientSecret()).isEqualTo("");
} }
@Test @Test
public void buildWhenClientCredentialsGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { public void buildWhenClientCredentialsGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).tokenUri(TOKEN_URI).build(); .clientSecret(CLIENT_SECRET)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.tokenUri(TOKEN_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
} }
@ -416,10 +641,17 @@ public class ClientRegistrationTests {
@Test @Test
public void buildWhenPasswordGrantAllAttributesProvidedThenAllAttributesAreSet() { public void buildWhenPasswordGrantAllAttributesProvidedThenAllAttributesAreSet() {
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.authorizationGrantType(AuthorizationGrantType.PASSWORD).scope(SCOPES.toArray(new String[0])) .clientId(CLIENT_ID)
.tokenUri(TOKEN_URI).clientName(CLIENT_NAME).build(); .clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.scope(SCOPES.toArray(new String[0]))
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
@ -432,50 +664,91 @@ public class ClientRegistrationTests {
@Test @Test
public void buildWhenPasswordGrantRegistrationIdIsNullThenThrowIllegalArgumentException() { public void buildWhenPasswordGrantRegistrationIdIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> ClientRegistration.withRegistrationId(null).clientId(CLIENT_ID) .isThrownBy(() -> ClientRegistration.withRegistrationId(null)
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.PASSWORD).tokenUri(TOKEN_URI).build()); .clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri(TOKEN_URI)
.build()
);
// @formatter:on
} }
@Test @Test
public void buildWhenPasswordGrantClientIdIsNullThenThrowIllegalArgumentException() { public void buildWhenPasswordGrantClientIdIsNullThenThrowIllegalArgumentException() {
assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID) // @formatter:off
.clientId(null).clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistration
.authorizationGrantType(AuthorizationGrantType.PASSWORD).tokenUri(TOKEN_URI).build()); .withRegistrationId(REGISTRATION_ID)
.clientId(null)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri(TOKEN_URI)
.build()
);
// @formatter:on
} }
@Test @Test
public void buildWhenPasswordGrantClientSecretIsNullThenDefaultToEmpty() { public void buildWhenPasswordGrantClientSecretIsNullThenDefaultToEmpty() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(null).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.PASSWORD).tokenUri(TOKEN_URI).build(); .clientSecret(null)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri(TOKEN_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientSecret()).isEqualTo(""); assertThat(clientRegistration.getClientSecret()).isEqualTo("");
} }
@Test @Test
public void buildWhenPasswordGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() { public void buildWhenPasswordGrantClientAuthenticationMethodNotProvidedThenDefaultToBasic() {
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).clientSecret(CLIENT_SECRET).authorizationGrantType(AuthorizationGrantType.PASSWORD) .clientId(CLIENT_ID)
.tokenUri(TOKEN_URI).build(); .clientSecret(CLIENT_SECRET)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri(TOKEN_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(ClientAuthenticationMethod.BASIC);
} }
@Test @Test
public void buildWhenPasswordGrantTokenUriIsNullThenThrowIllegalArgumentException() { public void buildWhenPasswordGrantTokenUriIsNullThenThrowIllegalArgumentException() {
// @formatter:off
assertThatIllegalArgumentException() assertThatIllegalArgumentException()
.isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) .isThrownBy(() -> ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId(CLIENT_ID)
.authorizationGrantType(AuthorizationGrantType.PASSWORD).tokenUri(null).build()); .clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri(null)
.build()
);
// @formatter:on
} }
@Test @Test
public void buildWhenCustomGrantAllAttributesProvidedThenAllAttributesAreSet() { public void buildWhenCustomGrantAllAttributesProvidedThenAllAttributesAreSet() {
AuthorizationGrantType customGrantType = new AuthorizationGrantType("CUSTOM"); AuthorizationGrantType customGrantType = new AuthorizationGrantType("CUSTOM");
ClientRegistration registration = ClientRegistration.withRegistrationId(REGISTRATION_ID).clientId(CLIENT_ID) // @formatter:off
.clientSecret(CLIENT_SECRET).clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) ClientRegistration registration = ClientRegistration
.authorizationGrantType(customGrantType).scope(SCOPES.toArray(new String[0])).tokenUri(TOKEN_URI) .withRegistrationId(REGISTRATION_ID)
.clientName(CLIENT_NAME).build(); .clientId(CLIENT_ID)
.clientSecret(CLIENT_SECRET)
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(customGrantType)
.scope(SCOPES.toArray(new String[0]))
.tokenUri(TOKEN_URI)
.clientName(CLIENT_NAME)
.build();
// @formatter:on
assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID); assertThat(registration.getRegistrationId()).isEqualTo(REGISTRATION_ID);
assertThat(registration.getClientId()).isEqualTo(CLIENT_ID); assertThat(registration.getClientId()).isEqualTo(CLIENT_ID);
assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET); assertThat(registration.getClientSecret()).isEqualTo(CLIENT_SECRET);
@ -532,9 +805,13 @@ public class ClientRegistrationTests {
@Test @Test
public void buildWhenClientRegistrationValuesOverriddenThenPropagated() { public void buildWhenClientRegistrationValuesOverriddenThenPropagated() {
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build(); ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().build();
// @formatter:off
ClientRegistration updated = ClientRegistration.withClientRegistration(clientRegistration) ClientRegistration updated = ClientRegistration.withClientRegistration(clientRegistration)
.clientSecret("a-new-secret").scope("a-new-scope") .clientSecret("a-new-secret")
.providerConfigurationMetadata(Collections.singletonMap("a-new-config", "a-new-value")).build(); .scope("a-new-scope")
.providerConfigurationMetadata(Collections.singletonMap("a-new-config", "a-new-value"))
.build();
// @formatter:on
assertThat(clientRegistration.getClientSecret()).isNotEqualTo(updated.getClientSecret()); assertThat(clientRegistration.getClientSecret()).isNotEqualTo(updated.getClientSecret());
assertThat(updated.getClientSecret()).isEqualTo("a-new-secret"); assertThat(updated.getClientSecret()).isEqualTo("a-new-secret");
assertThat(clientRegistration.getScopes()).doesNotContain("a-new-scope"); assertThat(clientRegistration.getScopes()).doesNotContain("a-new-scope");
@ -549,10 +826,16 @@ public class ClientRegistrationTests {
@Test @Test
public void buildWhenCustomClientAuthenticationMethodProvidedThenSet() { public void buildWhenCustomClientAuthenticationMethodProvidedThenSet() {
ClientAuthenticationMethod clientAuthenticationMethod = new ClientAuthenticationMethod("tls_client_auth"); ClientAuthenticationMethod clientAuthenticationMethod = new ClientAuthenticationMethod("tls_client_auth");
// @formatter:off
ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID) ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(REGISTRATION_ID)
.clientId(CLIENT_ID).authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .clientId(CLIENT_ID)
.clientAuthenticationMethod(clientAuthenticationMethod).redirectUri(REDIRECT_URI) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.authorizationUri(AUTHORIZATION_URI).tokenUri(TOKEN_URI).build(); .clientAuthenticationMethod(clientAuthenticationMethod)
.redirectUri(REDIRECT_URI)
.authorizationUri(AUTHORIZATION_URI)
.tokenUri(TOKEN_URI)
.build();
// @formatter:on
assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(clientAuthenticationMethod); assertThat(clientRegistration.getClientAuthenticationMethod()).isEqualTo(clientAuthenticationMethod);
} }

View File

@ -48,27 +48,61 @@ public class ClientRegistrationsTests {
/** /**
* Contains all optional parameters that are found in ClientRegistration * Contains all optional parameters that are found in ClientRegistration
*/ */
// @formatter:off
private static final String DEFAULT_RESPONSE = "{\n" private static final String DEFAULT_RESPONSE = "{\n"
+ " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n" + " \"authorization_endpoint\": \"https://example.com/o/oauth2/v2/auth\", \n"
+ " \"claims_supported\": [\n" + " \"aud\", \n" + " \"email\", \n" + " \"claims_supported\": [\n"
+ " \"email_verified\", \n" + " \"exp\", \n" + " \"family_name\", \n" + " \"aud\", \n"
+ " \"given_name\", \n" + " \"iat\", \n" + " \"iss\", \n" + " \"locale\", \n" + " \"email\", \n"
+ " \"name\", \n" + " \"picture\", \n" + " \"sub\"\n" + " ], \n" + " \"email_verified\", \n"
+ " \"code_challenge_methods_supported\": [\n" + " \"plain\", \n" + " \"S256\"\n" + " \"exp\", \n"
+ " ], \n" + " \"id_token_signing_alg_values_supported\": [\n" + " \"RS256\"\n" + " ], \n" + " \"family_name\", \n"
+ " \"given_name\", \n"
+ " \"iat\", \n"
+ " \"iss\", \n"
+ " \"locale\", \n"
+ " \"name\", \n"
+ " \"picture\", \n"
+ " \"sub\"\n"
+ " ], \n"
+ " \"code_challenge_methods_supported\": [\n"
+ " \"plain\", \n"
+ " \"S256\"\n"
+ " ], \n"
+ " \"id_token_signing_alg_values_supported\": [\n"
+ " \"RS256\"\n"
+ " ], \n"
+ " \"issuer\": \"https://example.com\", \n" + " \"issuer\": \"https://example.com\", \n"
+ " \"jwks_uri\": \"https://example.com/oauth2/v3/certs\", \n" + " \"response_types_supported\": [\n" + " \"jwks_uri\": \"https://example.com/oauth2/v3/certs\", \n"
+ " \"code\", \n" + " \"token\", \n" + " \"id_token\", \n" + " \"response_types_supported\": [\n"
+ " \"code token\", \n" + " \"code id_token\", \n" + " \"token id_token\", \n" + " \"code\", \n"
+ " \"code token id_token\", \n" + " \"none\"\n" + " ], \n" + " \"token\", \n"
+ " \"id_token\", \n"
+ " \"code token\", \n"
+ " \"code id_token\", \n"
+ " \"token id_token\", \n"
+ " \"code token id_token\", \n"
+ " \"none\"\n"
+ " ], \n"
+ " \"revocation_endpoint\": \"https://example.com/o/oauth2/revoke\", \n" + " \"revocation_endpoint\": \"https://example.com/o/oauth2/revoke\", \n"
+ " \"scopes_supported\": [\n" + " \"openid\", \n" + " \"email\", \n" + " \"scopes_supported\": [\n"
+ " \"profile\"\n" + " ], \n" + " \"subject_types_supported\": [\n" + " \"public\"\n" + " \"openid\", \n"
+ " ], \n" + " \"grant_types_supported\" : [\"authorization_code\"], \n" + " \"email\", \n"
+ " \"profile\"\n"
+ " ], \n"
+ " \"subject_types_supported\": [\n"
+ " \"public\"\n"
+ " ], \n"
+ " \"grant_types_supported\" : [\"authorization_code\"], \n"
+ " \"token_endpoint\": \"https://example.com/oauth2/v4/token\", \n" + " \"token_endpoint\": \"https://example.com/oauth2/v4/token\", \n"
+ " \"token_endpoint_auth_methods_supported\": [\n" + " \"client_secret_post\", \n" + " \"token_endpoint_auth_methods_supported\": [\n"
+ " \"client_secret_basic\", \n" + " \"none\"\n" + " ], \n" + " \"client_secret_post\", \n"
+ " \"userinfo_endpoint\": \"https://example.com/oauth2/v3/userinfo\"\n" + "}"; + " \"client_secret_basic\", \n"
+ " \"none\"\n"
+ " ], \n"
+ " \"userinfo_endpoint\": \"https://example.com/oauth2/v3/userinfo\"\n"
+ "}";
// @formatter:on
private MockWebServer server; private MockWebServer server;
@ -301,31 +335,43 @@ public class ClientRegistrationsTests {
@Test @Test
public void issuerWhenTokenEndpointAuthMethodsInvalidThenException() { public void issuerWhenTokenEndpointAuthMethodsInvalidThenException() {
this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth")); this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth"));
assertThatIllegalArgumentException().isThrownBy(() -> registration("")) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> registration(""))
.withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and " .withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and "
+ "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer
+ "\" returned a configuration of [tls_client_auth]"); + "\" returned a configuration of [tls_client_auth]");
// @formatter:on
} }
@Test @Test
public void issuerWhenOAuth2TokenEndpointAuthMethodsInvalidThenException() { public void issuerWhenOAuth2TokenEndpointAuthMethodsInvalidThenException() {
this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth")); this.response.put("token_endpoint_auth_methods_supported", Arrays.asList("tls_client_auth"));
assertThatIllegalArgumentException().isThrownBy(() -> registrationOAuth2("", null)) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> registrationOAuth2("", null))
.withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and " .withMessageContaining("Only ClientAuthenticationMethod.BASIC, ClientAuthenticationMethod.POST and "
+ "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer + "ClientAuthenticationMethod.NONE are supported. The issuer \"" + this.issuer
+ "\" returned a configuration of [tls_client_auth]"); + "\" returned a configuration of [tls_client_auth]");
// @formatter:on
} }
@Test @Test
public void issuerWhenOAuth2EmptyStringThenMeaningfulErrorMessage() { public void issuerWhenOAuth2EmptyStringThenMeaningfulErrorMessage() {
assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistrations.fromIssuerLocation("")) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> ClientRegistrations.fromIssuerLocation(""))
.withMessageContaining("issuer cannot be empty"); .withMessageContaining("issuer cannot be empty");
// @formatter:on
} }
@Test @Test
public void issuerWhenEmptyStringThenMeaningfulErrorMessage() { public void issuerWhenEmptyStringThenMeaningfulErrorMessage() {
assertThatIllegalArgumentException().isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation("")) // @formatter:off
assertThatIllegalArgumentException()
.isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation(""))
.withMessageContaining("issuer cannot be empty"); .withMessageContaining("issuer cannot be empty");
// @formatter:on
} }
@Test @Test
@ -335,9 +381,12 @@ public class ClientRegistrationsTests {
MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE,
MediaType.APPLICATION_JSON_VALUE); MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(mockResponse); this.server.enqueue(mockResponse);
assertThatIllegalStateException().isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation(this.issuer)) // @formatter:off
assertThatIllegalStateException()
.isThrownBy(() -> ClientRegistrations.fromOidcIssuerLocation(this.issuer))
.withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata did " .withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata did "
+ "not match the requested issuer \"" + this.issuer + "\""); + "not match the requested issuer \"" + this.issuer + "\"");
// @formatter:on
} }
@Test @Test
@ -347,20 +396,28 @@ public class ClientRegistrationsTests {
MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE,
MediaType.APPLICATION_JSON_VALUE); MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(mockResponse); this.server.enqueue(mockResponse);
assertThatIllegalStateException().isThrownBy(() -> ClientRegistrations.fromIssuerLocation(this.issuer)) // @formatter:off
assertThatIllegalStateException()
.isThrownBy(() -> ClientRegistrations.fromIssuerLocation(this.issuer))
.withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata " .withMessageContaining("The Issuer \"https://example.com\" provided in the configuration metadata "
+ "did not match the requested issuer \"" + this.issuer + "\""); + "did not match the requested issuer \"" + this.issuer + "\"");
// @formatter:on
} }
private ClientRegistration.Builder registration(String path) throws Exception { private ClientRegistration.Builder registration(String path) throws Exception {
this.issuer = createIssuerFromServer(path); this.issuer = createIssuerFromServer(path);
this.response.put("issuer", this.issuer); this.response.put("issuer", this.issuer);
String body = this.mapper.writeValueAsString(this.response); String body = this.mapper.writeValueAsString(this.response);
MockResponse mockResponse = new MockResponse().setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, // @formatter:off
MockResponse mockResponse = new MockResponse()
.setBody(body)
.setHeader(HttpHeaders.CONTENT_TYPE,
MediaType.APPLICATION_JSON_VALUE); MediaType.APPLICATION_JSON_VALUE);
this.server.enqueue(mockResponse); this.server.enqueue(mockResponse);
return ClientRegistrations.fromOidcIssuerLocation(this.issuer).clientId("client-id") return ClientRegistrations.fromOidcIssuerLocation(this.issuer)
.clientId("client-id")
.clientSecret("client-secret"); .clientSecret("client-secret");
// @formatter:on
} }
private ClientRegistration.Builder registrationOAuth2(String path, String body) throws Exception { private ClientRegistration.Builder registrationOAuth2(String path, String body) throws Exception {
@ -380,7 +437,11 @@ public class ClientRegistrationsTests {
} }
}; };
this.server.setDispatcher(dispatcher); this.server.setDispatcher(dispatcher);
return ClientRegistrations.fromIssuerLocation(this.issuer).clientId("client-id").clientSecret("client-secret"); // @formatter:off
return ClientRegistrations.fromIssuerLocation(this.issuer)
.clientId("client-id")
.clientSecret("client-secret");
// @formatter:on
} }
private String createIssuerFromServer(String path) { private String createIssuerFromServer(String path) {
@ -416,8 +477,11 @@ public class ClientRegistrationsTests {
} }
private MockResponse buildSuccessMockResponse(String body) { private MockResponse buildSuccessMockResponse(String body) {
return new MockResponse().setResponseCode(200).setBody(body).setHeader(HttpHeaders.CONTENT_TYPE, // @formatter:off
MediaType.APPLICATION_JSON_VALUE); return new MockResponse().setResponseCode(200)
.setBody(body)
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE);
// @formatter:on
} }
} }

View File

@ -75,8 +75,11 @@ public class InMemoryReactiveClientRegistrationRepositoryTests {
@Test @Test
public void findByRegistrationIdWhenValidIdThenFound() { public void findByRegistrationIdWhenValidIdThenFound() {
// @formatter:off
StepVerifier.create(this.repository.findByRegistrationId(this.registration.getRegistrationId())) StepVerifier.create(this.repository.findByRegistrationId(this.registration.getRegistrationId()))
.expectNext(this.registration).verifyComplete(); .expectNext(this.registration)
.verifyComplete();
// @formatter:on
} }
@Test @Test

View File

@ -29,39 +29,61 @@ public final class TestClientRegistrations {
} }
public static ClientRegistration.Builder clientRegistration() { public static ClientRegistration.Builder clientRegistration() {
// @formatter:off
return ClientRegistration.withRegistrationId("registration-id") return ClientRegistration.withRegistrationId("registration-id")
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).scope("read:user") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://example.com/login/oauth/authorize") .authorizationUri("https://example.com/login/oauth/authorize")
.tokenUri("https://example.com/login/oauth/access_token").jwkSetUri("https://example.com/oauth2/jwk") .tokenUri("https://example.com/login/oauth/access_token")
.issuerUri("https://example.com").userInfoUri("https://api.example.com/user") .jwkSetUri("https://example.com/oauth2/jwk")
.userNameAttributeName("id").clientName("Client Name").clientId("client-id") .issuerUri("https://example.com")
.userInfoUri("https://api.example.com/user")
.userNameAttributeName("id")
.clientName("Client Name")
.clientId("client-id")
.clientSecret("client-secret"); .clientSecret("client-secret");
// @formatter:on
} }
public static ClientRegistration.Builder clientRegistration2() { public static ClientRegistration.Builder clientRegistration2() {
// @formatter:off
return ClientRegistration.withRegistrationId("registration-id-2") return ClientRegistration.withRegistrationId("registration-id-2")
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).scope("read:user") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://example.com/login/oauth/authorize") .authorizationUri("https://example.com/login/oauth/authorize")
.tokenUri("https://example.com/login/oauth/access_token").userInfoUri("https://api.example.com/user") .tokenUri("https://example.com/login/oauth/access_token")
.userNameAttributeName("id").clientName("Client Name").clientId("client-id-2") .userInfoUri("https://api.example.com/user")
.userNameAttributeName("id")
.clientName("Client Name")
.clientId("client-id-2")
.clientSecret("client-secret"); .clientSecret("client-secret");
// @formatter:on
} }
public static ClientRegistration.Builder clientCredentials() { public static ClientRegistration.Builder clientCredentials() {
return clientRegistration().registrationId("client-credentials").clientId("client-id") // @formatter:off
return clientRegistration()
.registrationId("client-credentials")
.clientId("client-id")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS); .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS);
// @formatter:on
} }
public static ClientRegistration.Builder password() { public static ClientRegistration.Builder password() {
// @formatter:off
return ClientRegistration.withRegistrationId("password") return ClientRegistration.withRegistrationId("password")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.PASSWORD).scope("read", "write") .authorizationGrantType(AuthorizationGrantType.PASSWORD)
.tokenUri("https://example.com/login/oauth/access_token").clientName("Client Name") .scope("read", "write")
.clientId("client-id").clientSecret("client-secret"); .tokenUri("https://example.com/login/oauth/access_token")
.clientName("Client Name")
.clientId("client-id")
.clientSecret("client-secret");
// @formatter:on
} }
} }

View File

@ -69,7 +69,10 @@ public class CustomUserTypesOAuth2UserServiceTests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
String registrationId = "client-registration-id-1"; String registrationId = "client-registration-id-1";
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().registrationId(registrationId); // @formatter:off
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
.registrationId(registrationId);
// @formatter:on
this.accessToken = TestOAuth2AccessTokens.noScopes(); this.accessToken = TestOAuth2AccessTokens.noScopes();
Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>(); Map<String, Class<? extends OAuth2User>> customUserTypes = new HashMap<>();
customUserTypes.put(registrationId, CustomOAuth2User.class); customUserTypes.put(registrationId, CustomOAuth2User.class);
@ -113,16 +116,25 @@ public class CustomUserTypesOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() { public void loadUserWhenCustomUserTypeNotFoundThenReturnNull() {
// @formatter:off
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration() ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration()
.registrationId("other-client-registration-id-1").build(); .registrationId("other-client-registration-id-1")
.build();
// @formatter:on
OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); OAuth2User user = this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
assertThat(user).isNull(); assertThat(user).isNull();
} }
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" + " \"id\": \"12345\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"login\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; String userInfoResponse = "{\n"
+ " \"id\": \"12345\",\n"
+ " \"name\": \"first last\",\n"
+ " \"login\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -142,9 +154,15 @@ public class CustomUserTypesOAuth2UserServiceTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString( this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoResponse = "{\n" + " \"id\": \"12345\",\n" + " \"name\": \"first last\",\n" // @formatter:off
+ " \"login\": \"user1\",\n" + " \"email\": \"user1@example.com\"\n"; String userInfoResponse = "{\n"
// "}\n"; // Make the JSON invalid/malformed + " \"id\": \"12345\",\n"
+ " \"name\": \"first last\",\n"
+ " \"login\": \"user1\",\n"
+ " \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build(); ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri).build();
@ -173,9 +191,12 @@ public class CustomUserTypesOAuth2UserServiceTests {
} }
private ClientRegistration.Builder withRegistrationId(String registrationId) { private ClientRegistration.Builder withRegistrationId(String registrationId) {
// @formatter:off
return ClientRegistration.withRegistrationId(registrationId) return ClientRegistration.withRegistrationId(registrationId)
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).clientId("client") .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.clientId("client")
.tokenUri("/token"); .tokenUri("/token");
// @formatter:on
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {

View File

@ -80,8 +80,11 @@ public class DefaultOAuth2UserServiceTests {
public void setup() throws Exception { public void setup() throws Exception {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration().userInfoUri(null) // @formatter:off
this.clientRegistrationBuilder = TestClientRegistrations.clientRegistration()
.userInfoUri(null)
.userNameAttributeName(null); .userNameAttributeName(null);
// @formatter:on
this.accessToken = TestOAuth2AccessTokens.noScopes(); this.accessToken = TestOAuth2AccessTokens.noScopes();
} }
@ -120,16 +123,26 @@ public class DefaultOAuth2UserServiceTests {
public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString("missing_user_name_attribute")); this.exception.expectMessage(containsString("missing_user_name_attribute"));
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri("https://provider.com/user") // @formatter:off
ClientRegistration clientRegistration = this.clientRegistrationBuilder
.userInfoUri("https://provider.com/user")
.build(); .build();
// @formatter:on
this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken)); this.userService.loadUser(new OAuth2UserRequest(clientRegistration, this.accessToken));
} }
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"user-name\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -155,10 +168,16 @@ public class DefaultOAuth2UserServiceTests {
this.exception.expect(OAuth2AuthenticationException.class); this.exception.expect(OAuth2AuthenticationException.class);
this.exception.expectMessage(containsString( this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n"; + " \"user-name\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed // "}\n"; // Make the JSON invalid/malformed
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -190,7 +209,11 @@ public class DefaultOAuth2UserServiceTests {
this.exception.expectMessage(containsString( this.exception.expectMessage(containsString(
"[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource")); "[invalid_user_info_response] An error occurred while attempting to retrieve the UserInfo Resource"));
this.exception.expectMessage(containsString("Error Code: invalid_token")); this.exception.expectMessage(containsString("Error Code: invalid_token"));
String userInfoErrorResponse = "{\n" + " \"error\": \"invalid_token\"\n" + "}\n"; // @formatter:off
String userInfoErrorResponse = "{\n"
+ " \"error\": \"invalid_token\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400)); this.server.enqueue(jsonResponse(userInfoErrorResponse).setResponseCode(400));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -224,9 +247,16 @@ public class DefaultOAuth2UserServiceTests {
// gh-5294 // gh-5294
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception { public void loadUserWhenUserInfoSuccessResponseThenAcceptHeaderJson() throws Exception {
String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"user-name\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -239,9 +269,16 @@ public class DefaultOAuth2UserServiceTests {
// gh-5500 // gh-5500
@Test @Test
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"user-name\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)
@ -257,9 +294,16 @@ public class DefaultOAuth2UserServiceTests {
// gh-5500 // gh-5500
@Test @Test
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
String userInfoResponse = "{\n" + " \"user-name\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"user-name\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(userInfoResponse)); this.server.enqueue(jsonResponse(userInfoResponse));
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri) ClientRegistration clientRegistration = this.clientRegistrationBuilder.userInfoUri(userInfoUri)

View File

@ -78,7 +78,10 @@ public class DefaultReactiveOAuth2UserServiceTests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
String userInfoUri = this.server.url("/user").toString(); String userInfoUri = this.server.url("/user").toString();
this.clientRegistration = TestClientRegistrations.clientRegistration().userInfoUri(userInfoUri); // @formatter:off
this.clientRegistration = TestClientRegistrations.clientRegistration()
.userInfoUri(userInfoUri);
// @formatter:on
} }
@After @After
@ -103,16 +106,28 @@ public class DefaultReactiveOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserNameAttributeNameIsNullThenThrowOAuth2AuthenticationException() {
this.clientRegistration.userNameAttributeName(null); this.clientRegistration.userNameAttributeName(null);
StepVerifier.create(this.userService.loadUser(oauth2UserRequest())).expectErrorSatisfies((ex) -> assertThat(ex) // @formatter:off
.isInstanceOf(OAuth2AuthenticationException.class).hasMessageContaining("missing_user_name_attribute")) StepVerifier.create(this.userService.loadUser(oauth2UserRequest()))
.expectErrorSatisfies((ex) -> assertThat(ex)
.isInstanceOf(OAuth2AuthenticationException.class)
.hasMessageContaining("missing_user_name_attribute")
)
.verify(); .verify();
// @formatter:on
} }
@Test @Test
public void loadUserWhenUserInfoSuccessResponseThenReturnUser() { public void loadUserWhenUserInfoSuccessResponseThenReturnUser() {
String userInfoResponse = "{\n" + " \"id\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"id\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
enqueueApplicationJsonBody(userInfoResponse); enqueueApplicationJsonBody(userInfoResponse);
OAuth2User user = this.userService.loadUser(oauth2UserRequest()).block(); OAuth2User user = this.userService.loadUser(oauth2UserRequest()).block();
assertThat(user.getName()).isEqualTo("user1"); assertThat(user.getName()).isEqualTo("user1");
@ -134,9 +149,16 @@ public class DefaultReactiveOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception { public void loadUserWhenAuthenticationMethodHeaderSuccessResponseThenHttpMethodGet() throws Exception {
this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.HEADER); this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.HEADER);
String userInfoResponse = "{\n" + " \"id\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"id\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
enqueueApplicationJsonBody(userInfoResponse); enqueueApplicationJsonBody(userInfoResponse);
this.userService.loadUser(oauth2UserRequest()).block(); this.userService.loadUser(oauth2UserRequest()).block();
RecordedRequest request = this.server.takeRequest(); RecordedRequest request = this.server.takeRequest();
@ -150,9 +172,16 @@ public class DefaultReactiveOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception { public void loadUserWhenAuthenticationMethodFormSuccessResponseThenHttpMethodPost() throws Exception {
this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.FORM); this.clientRegistration.userInfoAuthenticationMethod(AuthenticationMethod.FORM);
String userInfoResponse = "{\n" + " \"id\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n" + "}\n"; + " \"id\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n"
+ "}\n";
// @formatter:on
enqueueApplicationJsonBody(userInfoResponse); enqueueApplicationJsonBody(userInfoResponse);
this.userService.loadUser(oauth2UserRequest()).block(); this.userService.loadUser(oauth2UserRequest()).block();
RecordedRequest request = this.server.takeRequest(); RecordedRequest request = this.server.takeRequest();
@ -164,10 +193,16 @@ public class DefaultReactiveOAuth2UserServiceTests {
@Test @Test
public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() { public void loadUserWhenUserInfoSuccessResponseInvalidThenThrowOAuth2AuthenticationException() {
String userInfoResponse = "{\n" + " \"id\": \"user1\",\n" + " \"first-name\": \"first\",\n" // @formatter:off
+ " \"last-name\": \"last\",\n" + " \"middle-name\": \"middle\",\n" String userInfoResponse = "{\n"
+ " \"address\": \"address\",\n" + " \"email\": \"user1@example.com\"\n"; + " \"id\": \"user1\",\n"
+ " \"first-name\": \"first\",\n"
+ " \"last-name\": \"last\",\n"
+ " \"middle-name\": \"middle\",\n"
+ " \"address\": \"address\",\n"
+ " \"email\": \"user1@example.com\"\n";
// "}\n"; // Make the JSON invalid/malformed // "}\n"; // Make the JSON invalid/malformed
// @formatter:on
enqueueApplicationJsonBody(userInfoResponse); enqueueApplicationJsonBody(userInfoResponse);
assertThatExceptionOfType(OAuth2AuthenticationException.class) assertThatExceptionOfType(OAuth2AuthenticationException.class)
.isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block()) .isThrownBy(() -> this.userService.loadUser(oauth2UserRequest()).block())

View File

@ -48,12 +48,19 @@ public class OAuth2UserRequestTests {
@Before @Before
public void setUp() { public void setUp() {
this.clientRegistration = ClientRegistration.withRegistrationId("registration-1").clientId("client-1") // @formatter:off
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) this.clientRegistration = ClientRegistration.withRegistrationId("registration-1")
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).redirectUri("https://client.com") .clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("https://client.com")
.scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))) .scope(new LinkedHashSet<>(Arrays.asList("scope1", "scope2")))
.authorizationUri("https://provider.com/oauth2/authorization") .authorizationUri("https://provider.com/oauth2/authorization")
.tokenUri("https://provider.com/oauth2/token").clientName("Client 1").build(); .tokenUri("https://provider.com/oauth2/token")
.clientName("Client 1")
.build();
// @formatter:on
this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(), this.accessToken = new OAuth2AccessToken(OAuth2AccessToken.TokenType.BEARER, "access-token-1234", Instant.now(),
Instant.now().plusSeconds(60), new LinkedHashSet<>(Arrays.asList("scope1", "scope2"))); Instant.now().plusSeconds(60), new LinkedHashSet<>(Arrays.asList("scope1", "scope2")));
this.additionalParameters = new HashMap<>(); this.additionalParameters = new HashMap<>();

View File

@ -72,11 +72,18 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration1 = TestClientRegistrations.clientRegistration().build();
this.registration2 = TestClientRegistrations.clientRegistration2().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build();
this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build(); this.fineRedirectUriTemplateRegistration = fineRedirectUriTemplateClientRegistration().build();
// @formatter:off
this.pkceRegistration = TestClientRegistrations.clientRegistration() this.pkceRegistration = TestClientRegistrations.clientRegistration()
.registrationId("pkce-client-registration-id").clientId("pkce-client-id") .registrationId("pkce-client-registration-id")
.clientAuthenticationMethod(ClientAuthenticationMethod.NONE).clientSecret(null).build(); .clientId("pkce-client-id")
this.oidcRegistration = TestClientRegistrations.clientRegistration().registrationId("oidc-registration-id") .clientAuthenticationMethod(ClientAuthenticationMethod.NONE)
.scope(OidcScopes.OPENID).build(); .clientSecret(null)
.build();
this.oidcRegistration = TestClientRegistrations.clientRegistration()
.registrationId("oidc-registration-id")
.scope(OidcScopes.OPENID)
.build();
// @formatter:on
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1,
this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration, this.registration2, this.fineRedirectUriTemplateRegistration, this.pkceRegistration,
this.oidcRegistration); this.oidcRegistration);
@ -134,8 +141,11 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
+ "-invalid"; + "-invalid";
MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri); MockHttpServletRequest request = new MockHttpServletRequest("GET", requestUri);
request.setServletPath(requestUri); request.setServletPath(requestUri);
assertThatIllegalArgumentException().isThrownBy(() -> this.resolver.resolve(request)).withMessage( // @formatter:off
"Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid"); assertThatIllegalArgumentException()
.isThrownBy(() -> this.resolver.resolve(request))
.withMessage("Invalid Client Registration with Id: " + clientRegistration.getRegistrationId() + "-invalid");
// @formatter:on
} }
@Test @Test
@ -483,14 +493,20 @@ public class DefaultOAuth2AuthorizationRequestResolverTests {
} }
private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() { private static ClientRegistration.Builder fineRedirectUriTemplateClientRegistration() {
// @formatter:off
return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration") return ClientRegistration.withRegistrationId("fine-redirect-uri-template-client-registration")
.redirectUri("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseScheme}://{baseHost}{basePort}{basePath}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).scope("read:user") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://example.com/login/oauth/authorize") .authorizationUri("https://example.com/login/oauth/authorize")
.tokenUri("https://example.com/login/oauth/access_token").userInfoUri("https://api.example.com/user") .tokenUri("https://example.com/login/oauth/access_token")
.userNameAttributeName("id").clientName("Fine Redirect Uri Template Client") .userInfoUri("https://api.example.com/user")
.clientId("fine-redirect-uri-template-client").clientSecret("client-secret"); .userNameAttributeName("id")
.clientName("Fine Redirect Uri Template Client")
.clientId("fine-redirect-uri-template-client")
.clientSecret("client-secret");
// @formatter:on
} }
} }

View File

@ -199,11 +199,16 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
@Test @Test
public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() { public void authorizeWhenClientRegistrationNotFoundThenThrowIllegalArgumentException() {
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId("invalid-registration-id").principal(this.principal).attributes((attrs) -> { .withClientRegistrationId("invalid-registration-id")
.principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest)) assertThatIllegalArgumentException().isThrownBy(() -> this.authorizedClientManager.authorize(authorizeRequest))
.withMessage("Could not find ClientRegistration with id 'invalid-registration-id'"); .withMessage("Could not find ClientRegistration with id 'invalid-registration-id'");
} }
@ -213,12 +218,16 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() { public void authorizeWhenNotAuthorizedAndUnsupportedProviderThenNotAuthorized() {
given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(this.clientRegistration.getRegistrationId())))
.willReturn(this.clientRegistration); .willReturn(this.clientRegistration);
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attributes((attrs) -> { .attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -238,12 +247,16 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
.willReturn(this.clientRegistration); .willReturn(this.clientRegistration);
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(this.authorizedClient); .willReturn(this.authorizedClient);
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attributes((attrs) -> { .attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(authorizeRequest)); verify(this.contextAttributesMapper).apply(eq(authorizeRequest));
@ -269,12 +282,16 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(reauthorizedClient); .willReturn(reauthorizedClient);
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attributes((attrs) -> { .attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(any()); verify(this.contextAttributesMapper).apply(any());
@ -309,12 +326,16 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
}); });
this.request.addParameter(OAuth2ParameterNames.USERNAME, "username"); this.request.addParameter(OAuth2ParameterNames.USERNAME, "username");
this.request.addParameter(OAuth2ParameterNames.PASSWORD, "password"); this.request.addParameter(OAuth2ParameterNames.PASSWORD, "password");
// @formatter:off
OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest OAuth2AuthorizeRequest authorizeRequest = OAuth2AuthorizeRequest
.withClientRegistrationId(this.clientRegistration.getRegistrationId()).principal(this.principal) .withClientRegistrationId(this.clientRegistration.getRegistrationId())
.principal(this.principal)
.attributes((attrs) -> { .attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
this.authorizedClientManager.authorize(authorizeRequest); this.authorizedClientManager.authorize(authorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
@ -327,11 +348,15 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@Test @Test
public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() { public void reauthorizeWhenUnsupportedProviderThenNotReauthorized() {
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).attributes((attrs) -> { .principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -352,11 +377,15 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken()); this.principal.getName(), TestOAuth2AccessTokens.noScopes(), TestOAuth2RefreshTokens.refreshToken());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willReturn(reauthorizedClient); .willReturn(reauthorizedClient);
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).attributes((attrs) -> { .principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest); OAuth2AuthorizedClient authorizedClient = this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest)); verify(this.contextAttributesMapper).apply(eq(reauthorizeRequest));
@ -381,11 +410,15 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
this.authorizedClientManager this.authorizedClientManager
.setContextAttributesMapper(new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper()); .setContextAttributesMapper(new DefaultOAuth2AuthorizedClientManager.DefaultContextAttributesMapper());
this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write"); this.request.addParameter(OAuth2ParameterNames.SCOPE, "read write");
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).attributes((attrs) -> { .principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
this.authorizedClientManager.authorize(reauthorizeRequest); this.authorizedClientManager.authorize(reauthorizeRequest);
verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture()); verify(this.authorizedClientProvider).authorize(this.authorizationContextCaptor.capture());
OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue(); OAuth2AuthorizationContext authorizationContext = this.authorizationContextCaptor.getValue();
@ -401,11 +434,15 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
this.clientRegistration.getRegistrationId()); this.clientRegistration.getRegistrationId());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willThrow(authorizationException); .willThrow(authorizationException);
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).attributes((attrs) -> { .principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationException.class) assertThatExceptionOfType(ClientAuthorizationException.class)
.isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException); .isEqualTo(authorizationException);
@ -421,11 +458,15 @@ public class DefaultOAuth2AuthorizedClientManagerTests {
new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId()); new OAuth2Error("non-matching-error-code", null, null), this.clientRegistration.getRegistrationId());
given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class))) given(this.authorizedClientProvider.authorize(any(OAuth2AuthorizationContext.class)))
.willThrow(authorizationException); .willThrow(authorizationException);
// @formatter:off
OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient) OAuth2AuthorizeRequest reauthorizeRequest = OAuth2AuthorizeRequest.withAuthorizedClient(this.authorizedClient)
.principal(this.principal).attributes((attrs) -> { .principal(this.principal)
.attributes((attrs) -> {
attrs.put(HttpServletRequest.class.getName(), this.request); attrs.put(HttpServletRequest.class.getName(), this.request);
attrs.put(HttpServletResponse.class.getName(), this.response); attrs.put(HttpServletResponse.class.getName(), this.response);
}).build(); })
.build();
// @formatter:on
assertThatExceptionOfType(ClientAuthorizationException.class) assertThatExceptionOfType(ClientAuthorizationException.class)
.isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest)) .isThrownBy(() -> this.authorizedClientManager.authorize(reauthorizeRequest))
.isEqualTo(authorizationException); .isEqualTo(authorizationException);

View File

@ -77,9 +77,13 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
public void setUp() { public void setUp() {
this.registration1 = TestClientRegistrations.clientRegistration().build(); this.registration1 = TestClientRegistrations.clientRegistration().build();
this.registration2 = TestClientRegistrations.clientRegistration2().build(); this.registration2 = TestClientRegistrations.clientRegistration2().build();
this.registration3 = TestClientRegistrations.clientRegistration().registrationId("registration-3") // @formatter:off
this.registration3 = TestClientRegistrations.clientRegistration()
.registrationId("registration-3")
.authorizationGrantType(AuthorizationGrantType.IMPLICIT) .authorizationGrantType(AuthorizationGrantType.IMPLICIT)
.redirectUri("{baseUrl}/authorize/oauth2/implicit/{registrationId}").build(); .redirectUri("{baseUrl}/authorize/oauth2/implicit/{registrationId}")
.build();
// @formatter:on
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1,
this.registration2, this.registration3); this.registration2, this.registration3);
this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository); this.filter = new OAuth2AuthorizationRequestRedirectFilter(this.clientRegistrationRepository);
@ -307,13 +311,18 @@ public class OAuth2AuthorizationRequestRedirectFilterTests {
OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request); OAuth2AuthorizationRequest defaultAuthorizationRequest = defaultAuthorizationRequestResolver.resolve(request);
Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters()); Map<String, Object> additionalParameters = new HashMap<>(defaultAuthorizationRequest.getAdditionalParameters());
additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName)); additionalParameters.put(loginHintParamName, request.getParameter(loginHintParamName));
// @formatter:off
String customAuthorizationRequestUri = UriComponentsBuilder String customAuthorizationRequestUri = UriComponentsBuilder
.fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri()) .fromUriString(defaultAuthorizationRequest.getAuthorizationRequestUri())
.queryParam(loginHintParamName, additionalParameters.get(loginHintParamName)).build(true).toUriString(); .queryParam(loginHintParamName, additionalParameters.get(loginHintParamName))
.build(true)
.toUriString();
OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest OAuth2AuthorizationRequest result = OAuth2AuthorizationRequest
.from(defaultAuthorizationRequestResolver.resolve(request)) .from(defaultAuthorizationRequestResolver.resolve(request))
.additionalParameters(Collections.singletonMap("idp", request.getParameter("idp"))) .additionalParameters(Collections.singletonMap("idp", request.getParameter("idp")))
.authorizationRequestUri(customAuthorizationRequestUri).build(); .authorizationRequestUri(customAuthorizationRequestUri)
.build();
// @formatter:on
given(resolver.resolve(any())).willReturn(result); given(resolver.resolve(any())).willReturn(result);
OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver); OAuth2AuthorizationRequestRedirectFilter filter = new OAuth2AuthorizationRequestRedirectFilter(resolver);
filter.doFilter(request, response, filterChain); filter.doFilter(request, response, filterChain);

View File

@ -219,14 +219,21 @@ public class OAuth2LoginAuthenticationFilterTests {
request.addParameter(OAuth2ParameterNames.STATE, "state"); request.addParameter(OAuth2ParameterNames.STATE, "state");
MockHttpServletResponse response = new MockHttpServletResponse(); MockHttpServletResponse response = new MockHttpServletResponse();
FilterChain filterChain = mock(FilterChain.class); FilterChain filterChain = mock(FilterChain.class);
// @formatter:off
ClientRegistration registrationNotFound = ClientRegistration.withRegistrationId("registration-not-found") ClientRegistration registrationNotFound = ClientRegistration.withRegistrationId("registration-not-found")
.clientId("client-1").clientSecret("secret") .clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{baseUrl}/login/oauth2/code/{registrationId}").scope("user") .redirectUri("{baseUrl}/login/oauth2/code/{registrationId}")
.authorizationUri("https://provider.com/oauth2/authorize").tokenUri("https://provider.com/oauth2/token") .scope("user")
.userInfoUri("https://provider.com/oauth2/user").userNameAttributeName("id").clientName("client-1") .authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build(); .build();
// @formatter:on
this.setUpAuthorizationRequest(request, response, registrationNotFound, state); this.setUpAuthorizationRequest(request, response, registrationNotFound, state);
this.filter.doFilter(request, response, filterChain); this.filter.doFilter(request, response, filterChain);
ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor ArgumentCaptor<AuthenticationException> authenticationExceptionArgCaptor = ArgumentCaptor

View File

@ -110,18 +110,32 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
SecurityContext securityContext = SecurityContextHolder.createEmptyContext(); SecurityContext securityContext = SecurityContextHolder.createEmptyContext();
securityContext.setAuthentication(this.authentication); securityContext.setAuthentication(this.authentication);
SecurityContextHolder.setContext(securityContext); SecurityContextHolder.setContext(securityContext);
this.registration1 = ClientRegistration.withRegistrationId("client1").clientId("client-1") // @formatter:off
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) this.registration1 = ClientRegistration.withRegistrationId("client1")
.clientId("client-1")
.clientSecret("secret")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE) .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.redirectUri("{baseUrl}/login/oauth2/code/{registrationId}").scope("user") .redirectUri("{baseUrl}/login/oauth2/code/{registrationId}")
.authorizationUri("https://provider.com/oauth2/authorize").tokenUri("https://provider.com/oauth2/token") .scope("user")
.userInfoUri("https://provider.com/oauth2/user").userNameAttributeName("id").clientName("client-1") .authorizationUri("https://provider.com/oauth2/authorize")
.tokenUri("https://provider.com/oauth2/token")
.userInfoUri("https://provider.com/oauth2/user")
.userNameAttributeName("id")
.clientName("client-1")
.build(); .build();
this.registration2 = ClientRegistration.withRegistrationId("client2").clientId("client-2") this.registration2 = ClientRegistration.withRegistrationId("client2")
.clientSecret("secret").clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientId("client-2")
.authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS).scope("read", "write") .clientSecret("secret")
.tokenUri("https://provider.com/oauth2/token").build(); .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
this.registration3 = TestClientRegistrations.password().registrationId("client3").build(); .authorizationGrantType(AuthorizationGrantType.CLIENT_CREDENTIALS)
.scope("read", "write")
.tokenUri("https://provider.com/oauth2/token")
.build();
this.registration3 = TestClientRegistrations.password()
.registrationId("client3")
.build();
// @formatter:on
this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1, this.clientRegistrationRepository = new InMemoryClientRegistrationRepository(this.registration1,
this.registration2, this.registration3); this.registration2, this.registration3);
this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class); this.authorizedClientRepository = mock(OAuth2AuthorizedClientRepository.class);

View File

@ -115,7 +115,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
this.server = new MockWebServer(); this.server = new MockWebServer();
this.server.start(); this.server.start();
this.serverUrl = this.server.url("/").toString(); this.serverUrl = this.server.url("/").toString();
this.webClient = WebClient.builder().filter(this.authorizedClientFilter).build(); // @formatter:off
this.webClient = WebClient.builder()
.filter(this.authorizedClientFilter)
.build();
// @formatter:on
this.authentication = new TestingAuthenticationToken("principal", "password"); this.authentication = new TestingAuthenticationToken("principal", "password");
this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/").build()).build(); this.exchange = MockServerWebExchange.builder(MockServerHttpRequest.get("/").build()).build();
} }
@ -127,22 +131,35 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenResponse = "{\n"
+ " \"scope\": \"read write\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl) ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl)
.build(); .build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId())))
.willReturn(Mono.just(clientRegistration)); .willReturn(Mono.just(clientRegistration));
this.webClient.get().uri(this.serverUrl) // @formatter:off
this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration.getRegistrationId())) .clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve().bodyToMono(String.class) .retrieve()
.bodyToMono(String.class)
.subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)).block(); .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
.block();
// @formatter:on
assertThat(this.server.getRequestCount()).isEqualTo(2); assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizedClient.class); .forClass(OAuth2AuthorizedClient.class);
@ -153,9 +170,17 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"refreshed-access-token\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenResponse = "{\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"access_token\": \"refreshed-access-token\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl) ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl)
@ -189,10 +214,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenResponse = "{\n"
+ " \"scope\": \"read write\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
// Client 1 // Client 1
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
@ -207,16 +240,24 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
.tokenUri(this.serverUrl).build(); .tokenUri(this.serverUrl).build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId())))
.willReturn(Mono.just(clientRegistration2)); .willReturn(Mono.just(clientRegistration2));
this.webClient.get().uri(this.serverUrl) // @formatter:off
this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration1.getRegistrationId())) .clientRegistrationId(clientRegistration1.getRegistrationId()))
.retrieve().bodyToMono(String.class) .retrieve()
.flatMap((response) -> this.webClient.get().uri(this.serverUrl) .bodyToMono(String.class)
.flatMap((response) -> this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration2.getRegistrationId())) .clientRegistrationId(clientRegistration2.getRegistrationId()))
.retrieve().bodyToMono(String.class)) .retrieve()
.bodyToMono(String.class)
)
.subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)).block(); .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication))
.block();
// @formatter:on
assertThat(this.server.getRequestCount()).isEqualTo(4); assertThat(this.server.getRequestCount()).isEqualTo(4);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizedClient.class); .forClass(OAuth2AuthorizedClient.class);
@ -232,10 +273,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
*/ */
@Test @Test
public void requestWhenUnauthorizedThenReAuthorize() { public void requestWhenUnauthorizedThenReAuthorize() {
String accessTokenResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenResponse = "{\n"
+ " \"scope\": \"read write\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(new MockResponse().setResponseCode(HttpStatus.UNAUTHORIZED.value())); this.server.enqueue(new MockResponse().setResponseCode(HttpStatus.UNAUTHORIZED.value()));
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
@ -250,15 +299,22 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
doReturn(Mono.just(authorizedClient)).doReturn(Mono.empty()).when(this.authorizedClientRepository) doReturn(Mono.just(authorizedClient)).doReturn(Mono.empty()).when(this.authorizedClientRepository)
.loadAuthorizedClient(eq(clientRegistration.getRegistrationId()), eq(this.authentication), .loadAuthorizedClient(eq(clientRegistration.getRegistrationId()), eq(this.authentication),
eq(this.exchange)); eq(this.exchange));
Mono<String> requestMono = this.webClient.get().uri(this.serverUrl) // @formatter:off
Mono<String> requestMono = this.webClient.get()
.uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction
.clientRegistrationId(clientRegistration.getRegistrationId())) .clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve().bodyToMono(String.class) .retrieve()
.bodyToMono(String.class)
.subscriberContext(Context.of(ServerWebExchange.class, this.exchange)) .subscriberContext(Context.of(ServerWebExchange.class, this.exchange))
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication)); .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(this.authentication));
// @formatter:on
// first try should fail, and remove the cached authorized client // first try should fail, and remove the cached authorized client
assertThatExceptionOfType(WebClientResponseException.class).isThrownBy(requestMono::block) // @formatter:off
assertThatExceptionOfType(WebClientResponseException.class)
.isThrownBy(requestMono::block)
.satisfies((ex) -> assertThat(ex.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED)); .satisfies((ex) -> assertThat(ex.getStatusCode()).isEqualTo(HttpStatus.UNAUTHORIZED));
// @formatter:on
assertThat(this.server.getRequestCount()).isEqualTo(1); assertThat(this.server.getRequestCount()).isEqualTo(1);
verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository, never()).saveAuthorizedClient(any(), any(), any());
verify(this.authorizedClientRepository).removeAuthorizedClient(eq(clientRegistration.getRegistrationId()), verify(this.authorizedClientRepository).removeAuthorizedClient(eq(clientRegistration.getRegistrationId()),
@ -274,7 +330,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionITests {
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); // @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
} }
} }

View File

@ -108,6 +108,8 @@ import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyZeroInteractions;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient;
/** /**
* @author Rob Winch * @author Rob Winch
@ -161,14 +163,17 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Before @Before
public void setup() { public void setup() {
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().authorizationCode() .builder()
.authorizationCode()
.refreshToken( .refreshToken(
(configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient)) (configurer) -> configurer.accessTokenResponseClient(this.refreshTokenTokenResponseClient))
.clientCredentials( .clientCredentials(
(configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient)) (configurer) -> configurer.accessTokenResponseClient(this.clientCredentialsTokenResponseClient))
.password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient)) .password((configurer) -> configurer.accessTokenResponseClient(this.passwordTokenResponseClient))
.build(); .build();
// @formatter:on
this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( this.authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository); this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); this.authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
@ -220,9 +225,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken); this.accessToken);
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); // @formatter:off
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange())
.block();
// @formatter:on
assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION)) assertThat(this.exchange.getRequest().headers().getFirst(HttpHeaders.AUTHORIZATION))
.isEqualTo("Bearer " + this.accessToken.getTokenValue()); .isEqualTo("Bearer " + this.accessToken.getTokenValue());
} }
@ -231,10 +239,12 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() { public void filterWhenExistingAuthorizationThenSingleAuthorizationHeader() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken); this.accessToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.header(HttpHeaders.AUTHORIZATION, "Existing") .header(HttpHeaders.AUTHORIZATION, "Existing")
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
// @formatter:on
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block();
HttpHeaders headers = this.exchange.getRequest().headers(); HttpHeaders headers = this.exchange.getRequest().headers();
assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue()); assertThat(headers.get(HttpHeaders.AUTHORIZATION)).containsOnly("Bearer " + this.accessToken.getTokenValue());
@ -242,8 +252,13 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test @Test
public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() { public void filterWhenClientCredentialsTokenExpiredThenGetNewToken() {
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token") // @formatter:off
.tokenType(OAuth2AccessToken.TokenType.BEARER).expiresIn(360).build(); OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse
.withToken("new-token")
.tokenType(OAuth2AccessToken.TokenType.BEARER)
.expiresIn(360)
.build();
// @formatter:on
given(this.clientCredentialsTokenResponseClient.getTokenResponse(any())) given(this.clientCredentialsTokenResponseClient.getTokenResponse(any()))
.willReturn(Mono.just(accessTokenResponse)); .willReturn(Mono.just(accessTokenResponse));
ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
@ -254,12 +269,15 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", accessToken, OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", accessToken,
null); null);
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange()).block(); .subscriberContext(serverWebExchange())
.block();
// @formatter:on
verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), eq(authentication), any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
@ -277,12 +295,15 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(registration, "principalName",
this.accessToken, null); this.accessToken, null);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange()).block(); .subscriberContext(serverWebExchange())
.block();
// @formatter:on
verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any()); verify(this.clientCredentialsTokenResponseClient, never()).getTokenResponse(any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
@ -305,13 +326,18 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
// @formatter:on
TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this"); TestingAuthenticationToken authentication = new TestingAuthenticationToken("test", "this");
// @formatter:off
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.subscriberContext(serverWebExchange()).block(); .subscriberContext(serverWebExchange())
.block();
// @formatter:on
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(), verify(this.authorizedClientRepository).saveAuthorizedClient(this.authorizedClientCaptor.capture(),
eq(authentication), any()); eq(authentication), any());
@ -339,10 +365,14 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
// @formatter:on
verify(this.refreshTokenTokenResponseClient).getTokenResponse(any()); verify(this.refreshTokenTokenResponseClient).getTokenResponse(any());
verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any()); verify(this.authorizedClientRepository).saveAuthorizedClient(any(), any(), any());
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
@ -358,10 +388,14 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken); this.accessToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
// @formatter:on
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0); ClientRequest request0 = requests.get(0);
@ -376,10 +410,14 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
// @formatter:on
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
ClientRequest request0 = requests.get(0); ClientRequest request0 = requests.get(0);
@ -398,9 +436,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
// @formatter:on
given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.UNAUTHORIZED.value()); given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.UNAUTHORIZED.value());
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block();
assertThat(publisherProbe.wasSubscribed()).isTrue(); assertThat(publisherProbe.wasSubscribed()).isTrue();
@ -427,18 +467,27 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
// @formatter:on
WebClientResponseException exception = WebClientResponseException.create(HttpStatus.UNAUTHORIZED.value(), WebClientResponseException exception = WebClientResponseException.create(HttpStatus.UNAUTHORIZED.value(),
HttpStatus.UNAUTHORIZED.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8); HttpStatus.UNAUTHORIZED.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8);
ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception);
assertThatExceptionOfType(WebClientResponseException.class).isThrownBy(() -> this.function // @formatter:off
.filter(request, throwingExchangeFunction).subscriberContext(serverWebExchange()).block()) assertThatExceptionOfType(WebClientResponseException.class)
.isThrownBy(() -> this.function
.filter(request, throwingExchangeFunction)
.subscriberContext(serverWebExchange())
.block()
)
.isEqualTo(exception); .isEqualTo(exception);
// @formatter:on
assertThat(publisherProbe.wasSubscribed()).isTrue(); assertThat(publisherProbe.wasSubscribed()).isTrue();
verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(), this.attributesCaptor.capture()); this.authenticationCaptor.capture(), this.attributesCaptor.capture());
// @formatter:off
assertThat(this.authorizationExceptionCaptor.getValue()) assertThat(this.authorizationExceptionCaptor.getValue())
.isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> { .isInstanceOfSatisfying(ClientAuthorizationException.class, (ex) -> {
assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId()); assertThat(ex.getClientRegistrationId()).isEqualTo(this.registration.getRegistrationId());
@ -446,6 +495,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
assertThat(ex).hasCause(exception); assertThat(ex).hasCause(exception);
assertThat(ex).hasMessageContaining("[invalid_token]"); assertThat(ex).hasMessageContaining("[invalid_token]");
}); });
// @formatter:on
assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class); assertThat(this.authenticationCaptor.getValue()).isInstanceOf(AnonymousAuthenticationToken.class);
assertThat(this.attributesCaptor.getValue()) assertThat(this.attributesCaptor.getValue())
.containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange)); .containsExactly(entry(ServerWebExchange.class.getName(), this.serverWebExchange));
@ -460,9 +510,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt()); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
// @formatter:on
given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.FORBIDDEN.value()); given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.FORBIDDEN.value());
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block();
assertThat(publisherProbe.wasSubscribed()).isTrue(); assertThat(publisherProbe.wasSubscribed()).isTrue();
@ -490,14 +542,20 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
WebClientResponseException exception = WebClientResponseException.create(HttpStatus.FORBIDDEN.value(), WebClientResponseException exception = WebClientResponseException.create(HttpStatus.FORBIDDEN.value(),
HttpStatus.FORBIDDEN.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8); HttpStatus.FORBIDDEN.getReasonPhrase(), HttpHeaders.EMPTY, new byte[0], StandardCharsets.UTF_8);
ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception); ExchangeFunction throwingExchangeFunction = (r) -> Mono.error(exception);
assertThatExceptionOfType(WebClientResponseException.class).isThrownBy(() -> this.function // @formatter:off
.filter(request, throwingExchangeFunction).subscriberContext(serverWebExchange()).block()) assertThatExceptionOfType(WebClientResponseException.class)
.isThrownBy(() -> this.function
.filter(request, throwingExchangeFunction)
.subscriberContext(serverWebExchange())
.block()
)
.isEqualTo(exception); .isEqualTo(exception);
// @formatter:on
assertThat(publisherProbe.wasSubscribed()).isTrue(); assertThat(publisherProbe.wasSubscribed()).isTrue();
verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(), verify(this.authorizationFailureHandler).onAuthorizationFailure(this.authorizationExceptionCaptor.capture(),
this.authenticationCaptor.capture(), this.attributesCaptor.capture()); this.authenticationCaptor.capture(), this.attributesCaptor.capture());
@ -523,7 +581,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", " String wwwAuthenticateHeader = "Bearer error=\"insufficient_scope\", "
+ "error_description=\"The request requires higher privileges than provided by the access token.\", " + "error_description=\"The request requires higher privileges than provided by the access token.\", "
@ -561,7 +619,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
OAuth2AuthorizationException exception = new OAuth2AuthorizationException( OAuth2AuthorizationException exception = new OAuth2AuthorizationException(
new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null)); new OAuth2Error(OAuth2ErrorCodes.INVALID_TOKEN, null, null));
@ -585,7 +643,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName",
this.accessToken, refreshToken); this.accessToken, refreshToken);
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction.oauth2AuthorizedClient(authorizedClient)) .attributes(oauth2AuthorizedClient(authorizedClient))
.build(); .build();
given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.BAD_REQUEST.value()); given(this.exchange.getResponse().rawStatusCode()).willReturn(HttpStatus.BAD_REQUEST.value());
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block();
@ -621,8 +679,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
.contentType(MediaType.APPLICATION_FORM_URLENCODED).body("username=username&password=password")) .contentType(MediaType.APPLICATION_FORM_URLENCODED).body("username=username&password=password"))
.build(); .build();
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(registration.getRegistrationId()))
.clientRegistrationId(registration.getRegistrationId()))
.build(); .build();
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)) .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
@ -646,8 +703,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any()))
.willReturn(Mono.just(authorizedClient)); .willReturn(Mono.just(authorizedClient));
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(this.registration.getRegistrationId()))
.clientRegistrationId(this.registration.getRegistrationId()))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block();
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
@ -710,8 +766,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
Collections.singletonMap("user", "rob"), "user"); Collections.singletonMap("user", "rob"), "user");
OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(), OAuth2AuthenticationToken authentication = new OAuth2AuthenticationToken(user, user.getAuthorities(),
"client-id"); "client-id");
// @formatter:off
this.function.filter(request, this.exchange) this.function.filter(request, this.exchange)
.subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication)).block(); .subscriberContext(ReactiveSecurityContextHolder.withAuthentication(authentication))
.block();
// @formatter:on
List<ClientRequest> requests = this.exchange.getRequests(); List<ClientRequest> requests = this.exchange.getRequests();
assertThat(requests).hasSize(1); assertThat(requests).hasSize(1);
verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository); verifyZeroInteractions(this.clientRegistrationRepository, this.authorizedClientRepository);
@ -724,11 +783,14 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken, refreshToken); this.accessToken, refreshToken);
given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any())) given(this.authorizedClientRepository.loadAuthorizedClient(any(), any(), any()))
.willReturn(Mono.just(authorizedClient)); .willReturn(Mono.just(authorizedClient));
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(this.registration.getRegistrationId()))
.clientRegistrationId(this.registration.getRegistrationId()))
.build(); .build();
this.function.filter(request, this.exchange).subscriberContext(serverWebExchange()).block(); this.function.filter(request, this.exchange)
.subscriberContext(serverWebExchange())
.block();
// @formatter:on
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(),
eq(this.serverWebExchange)); eq(this.serverWebExchange));
} }
@ -750,10 +812,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
ClientRegistration registration = TestClientRegistrations.clientCredentials().build(); ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId())))
.willReturn(Mono.just(registration)); .willReturn(Mono.just(registration));
// @formatter:off
ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com")) ClientRequest request = ClientRequest.create(HttpMethod.GET, URI.create("https://example.com"))
.attributes(ServerOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(registration.getRegistrationId()))
.clientRegistrationId(registration.getRegistrationId()))
.build(); .build();
// @formatter:on
this.function.filter(request, this.exchange).block(); this.function.filter(request, this.exchange).block();
verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any()); verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any());
verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any()); verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());

View File

@ -65,6 +65,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times; import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.springframework.security.oauth2.client.web.reactive.function.client.ServletOAuth2AuthorizedClientExchangeFilterFunction.clientRegistrationId;
/** /**
* @author Joe Grandja * @author Joe Grandja
@ -99,7 +100,11 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
// BlockHound to error. // BlockHound to error.
// NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine // NOTE: This is an issue with JDK 8. It's been tested on JDK 10 and works fine
// w/o this white-list. // w/o this white-list.
BlockHound.builder().allowBlockingCallsInside(Class.class.getName(), "getPackage").install(); // @formatter:off
BlockHound.builder()
.allowBlockingCallsInside(Class.class.getName(), "getPackage")
.install();
// @formatter:on
} }
@Before @Before
@ -148,10 +153,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() { public void requestWhenNotAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenResponse = "{\n"
+ " \"scope\": \"read write\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl) ClientRegistration clientRegistration = TestClientRegistrations.clientCredentials().tokenUri(this.serverUrl)
@ -159,8 +172,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration.getRegistrationId())))
.willReturn(clientRegistration); .willReturn(clientRegistration);
this.webClient.get().uri(this.serverUrl) this.webClient.get().uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
.clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve().bodyToMono(String.class).block(); .retrieve().bodyToMono(String.class).block();
assertThat(this.server.getRequestCount()).isEqualTo(2); assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
@ -172,9 +184,17 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() { public void requestWhenAuthorizedButExpiredThenRefreshAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"refreshed-access-token\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\"\n" + "}\n"; String accessTokenResponse = "{\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"access_token\": \"refreshed-access-token\",\n"
+ " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl) ClientRegistration clientRegistration = TestClientRegistrations.clientRegistration().tokenUri(this.serverUrl)
@ -191,8 +211,7 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient( doReturn(authorizedClient).when(this.authorizedClientRepository).loadAuthorizedClient(
eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request)); eq(clientRegistration.getRegistrationId()), eq(this.authentication), eq(this.request));
this.webClient.get().uri(this.serverUrl) this.webClient.get().uri(this.serverUrl)
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .attributes(clientRegistrationId(clientRegistration.getRegistrationId()))
.clientRegistrationId(clientRegistration.getRegistrationId()))
.retrieve().bodyToMono(String.class).block(); .retrieve().bodyToMono(String.class).block();
assertThat(this.server.getRequestCount()).isEqualTo(2); assertThat(this.server.getRequestCount()).isEqualTo(2);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
@ -206,10 +225,18 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
@Test @Test
public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() { public void requestMultipleWhenNoneAuthorizedThenAuthorizeAndSendRequest() {
String accessTokenResponse = "{\n" + " \"access_token\": \"access-token-1234\",\n" // @formatter:off
+ " \"token_type\": \"bearer\",\n" + " \"expires_in\": \"3600\",\n" String accessTokenResponse = "{\n"
+ " \"scope\": \"read write\"\n" + "}\n"; + " \"access_token\": \"access-token-1234\",\n"
String clientResponse = "{\n" + " \"attribute1\": \"value1\",\n" + " \"attribute2\": \"value2\"\n" + "}\n"; + " \"token_type\": \"bearer\",\n"
+ " \"expires_in\": \"3600\",\n"
+ " \"scope\": \"read write\"\n"
+ "}\n";
String clientResponse = "{\n"
+ " \"attribute1\": \"value1\",\n"
+ " \"attribute2\": \"value2\"\n"
+ "}\n";
// @formatter:on
// Client 1 // Client 1
this.server.enqueue(jsonResponse(accessTokenResponse)); this.server.enqueue(jsonResponse(accessTokenResponse));
this.server.enqueue(jsonResponse(clientResponse)); this.server.enqueue(jsonResponse(clientResponse));
@ -224,15 +251,22 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
.tokenUri(this.serverUrl).build(); .tokenUri(this.serverUrl).build();
given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId()))) given(this.clientRegistrationRepository.findByRegistrationId(eq(clientRegistration2.getRegistrationId())))
.willReturn(clientRegistration2); .willReturn(clientRegistration2);
this.webClient.get().uri(this.serverUrl) // @formatter:off
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction this.webClient.get()
.clientRegistrationId(clientRegistration1.getRegistrationId())) .uri(this.serverUrl)
.retrieve().bodyToMono(String.class) .attributes(clientRegistrationId(clientRegistration1.getRegistrationId()))
.flatMap((response) -> this.webClient.get().uri(this.serverUrl) .retrieve()
.attributes(ServletOAuth2AuthorizedClientExchangeFilterFunction .bodyToMono(String.class)
.clientRegistrationId(clientRegistration2.getRegistrationId())) .flatMap((response) -> this.webClient
.retrieve().bodyToMono(String.class)) .get()
.subscriberContext(context()).block(); .uri(this.serverUrl)
.attributes(clientRegistrationId(clientRegistration2.getRegistrationId()))
.retrieve()
.bodyToMono(String.class)
)
.subscriberContext(context())
.block();
// @formatter:on
assertThat(this.server.getRequestCount()).isEqualTo(4); assertThat(this.server.getRequestCount()).isEqualTo(4);
ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor ArgumentCaptor<OAuth2AuthorizedClient> authorizedClientCaptor = ArgumentCaptor
.forClass(OAuth2AuthorizedClient.class); .forClass(OAuth2AuthorizedClient.class);
@ -252,7 +286,11 @@ public class ServletOAuth2AuthorizedClientExchangeFilterFunctionITests {
} }
private MockResponse jsonResponse(String json) { private MockResponse jsonResponse(String json) {
return new MockResponse().setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE).setBody(json); // @formatter:off
return new MockResponse()
.setHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE)
.setBody(json);
// @formatter:on
} }
} }

View File

@ -80,8 +80,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
@Before @Before
public void setUp() { public void setUp() {
// @formatter:off
ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider = ReactiveOAuth2AuthorizedClientProviderBuilder
.builder().authorizationCode().refreshToken().clientCredentials().build(); .builder()
.authorizationCode()
.refreshToken()
.clientCredentials()
.build();
// @formatter:on
DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager( DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
this.clientRegistrationRepository, this.authorizedClientRepository); this.clientRegistrationRepository, this.authorizedClientRepository);
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider); authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);

View File

@ -93,8 +93,12 @@ public class DefaultServerOAuth2AuthorizationRequestResolverTests {
@Test @Test
public void resolveWhenForwardedHeadersClientRegistrationFoundThenWorks() { public void resolveWhenForwardedHeadersClientRegistrationFoundThenWorks() {
given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration)); given(this.clientRegistrationRepository.findByRegistrationId(any())).willReturn(Mono.just(this.registration));
ServerWebExchange exchange = MockServerWebExchange // @formatter:off
.from(MockServerHttpRequest.get("/oauth2/authorization/id").header("X-Forwarded-Host", "evil.com")); MockServerHttpRequest.BaseBuilder<?> httpRequest = MockServerHttpRequest
.get("/oauth2/authorization/id")
.header("X-Forwarded-Host", "evil.com");
// @formatter:on
ServerWebExchange exchange = MockServerWebExchange.from(httpRequest);
OAuth2AuthorizationRequest request = this.resolver.resolve(exchange).block(); OAuth2AuthorizationRequest request = this.resolver.resolve(exchange).block();
assertThat(request.getAuthorizationRequestUri()) assertThat(request.getAuthorizationRequestUri())
.matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&" .matches("https://example.com/login/oauth/authorize\\?" + "response_type=code&client_id=client-id&"

View File

@ -301,8 +301,12 @@ public class OAuth2AuthorizationCodeGrantWebFilterTests {
MockServerHttpRequest authorizationRequest, ClientRegistration registration) { MockServerHttpRequest authorizationRequest, ClientRegistration registration) {
Map<String, Object> attributes = new HashMap<>(); Map<String, Object> attributes = new HashMap<>();
attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId()); attributes.put(OAuth2ParameterNames.REGISTRATION_ID, registration.getRegistrationId());
return TestOAuth2AuthorizationRequests.request().attributes(attributes) // @formatter:off
.redirectUri(authorizationRequest.getURI().toString()).build(); return TestOAuth2AuthorizationRequests.request()
.attributes(attributes)
.redirectUri(authorizationRequest.getURI().toString())
.build();
// @formatter:on
} }
private static MockServerHttpRequest createAuthorizationRequest(String requestUri) { private static MockServerHttpRequest createAuthorizationRequest(String requestUri) {

View File

@ -41,7 +41,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.verifyNoInteractions;
/** /**
* @author Rob Winch * @author Rob Winch
@ -86,15 +86,23 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
@Test @Test
public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() { public void filterWhenDoesNotMatchThenClientRegistrationRepositoryNotSubscribed() {
this.client.get().exchange().expectStatus().isOk(); // @formatter:off
verifyZeroInteractions(this.clientRepository, this.authzRequestRepository); this.client.get()
.exchange()
.expectStatus().isOk();
// @formatter:on
verifyNoInteractions(this.clientRepository, this.authzRequestRepository);
} }
@Test @Test
public void filterWhenDoesMatchThenClientRegistrationRepositoryNotSubscribed() { public void filterWhenDoesMatchThenClientRegistrationRepositoryNotSubscribed() {
// @formatter:off
FluxExchangeResult<String> result = this.client.get() FluxExchangeResult<String> result = this.client.get()
.uri("https://example.com/oauth2/authorization/registration-id").exchange().expectStatus() .uri("https://example.com/oauth2/authorization/registration-id")
.is3xxRedirection().returnResult(String.class); .exchange()
.expectStatus().is3xxRedirection()
.returnResult(String.class);
// @formatter:on
result.assertWithDiagnostics(() -> { result.assertWithDiagnostics(() -> {
URI location = result.getResponseHeaders().getLocation(); URI location = result.getResponseHeaders().getLocation();
assertThat(location).hasScheme("https").hasHost("example.com").hasPath("/login/oauth/authorize") assertThat(location).hasScheme("https").hasHost("example.com").hasPath("/login/oauth/authorize")
@ -108,16 +116,23 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
// gh-5520 // gh-5520
@Test @Test
public void filterWhenDoesMatchThenResolveRedirectUriExpandedExcludesQueryString() { public void filterWhenDoesMatchThenResolveRedirectUriExpandedExcludesQueryString() {
// @formatter:off
FluxExchangeResult<String> result = this.client.get() FluxExchangeResult<String> result = this.client.get()
.uri("https://example.com/oauth2/authorization/registration-id?foo=bar").exchange().expectStatus() .uri("https://example.com/oauth2/authorization/registration-id?foo=bar").exchange().expectStatus()
.is3xxRedirection().returnResult(String.class); .is3xxRedirection().returnResult(String.class);
result.assertWithDiagnostics(() -> { result.assertWithDiagnostics(() -> {
URI location = result.getResponseHeaders().getLocation(); URI location = result.getResponseHeaders().getLocation();
assertThat(location).hasScheme("https").hasHost("example.com").hasPath("/login/oauth/authorize") assertThat(location)
.hasParameter("response_type", "code").hasParameter("client_id", "client-id") .hasScheme("https")
.hasParameter("scope", "read:user").hasParameter("state") .hasHost("example.com")
.hasPath("/login/oauth/authorize")
.hasParameter("response_type", "code")
.hasParameter("client_id", "client-id")
.hasParameter("scope", "read:user")
.hasParameter("state")
.hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id"); .hasParameter("redirect_uri", "https://example.com/login/oauth2/code/registration-id");
}); });
// @formatter:on
} }
@Test @Test
@ -125,9 +140,15 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
FilteringWebHandler webHandler = new FilteringWebHandler( FilteringWebHandler webHandler = new FilteringWebHandler(
(e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), (e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())),
Arrays.asList(this.filter)); Arrays.asList(this.filter));
this.client = WebTestClient.bindToWebHandler(webHandler).build(); // @formatter:off
FluxExchangeResult<String> result = this.client.get().uri("https://example.com/foo").exchange().expectStatus() this.client = WebTestClient.bindToWebHandler(webHandler)
.is3xxRedirection().returnResult(String.class); .build();
FluxExchangeResult<String> result = this.client.get()
.uri("https://example.com/foo")
.exchange()
.expectStatus().is3xxRedirection()
.returnResult(String.class);
// @formatter:on
} }
@Test @Test
@ -137,18 +158,29 @@ public class OAuth2AuthorizationRequestRedirectWebFilterTests {
FilteringWebHandler webHandler = new FilteringWebHandler( FilteringWebHandler webHandler = new FilteringWebHandler(
(e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())), (e) -> Mono.error(new ClientAuthorizationRequiredException(this.registration.getRegistrationId())),
Arrays.asList(this.filter)); Arrays.asList(this.filter));
this.client = WebTestClient.bindToWebHandler(webHandler).build(); // @formatter:off
this.client.get().uri("https://example.com/foo").exchange().expectStatus().is3xxRedirection() this.client = WebTestClient.bindToWebHandler(webHandler)
.build();
this.client.get()
.uri("https://example.com/foo")
.exchange()
.expectStatus().is3xxRedirection()
.returnResult(String.class); .returnResult(String.class);
// @formatter:on
verify(this.requestCache).saveRequest(any()); verify(this.requestCache).saveRequest(any());
} }
@Test @Test
public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() { public void filterWhenPathMatchesThenRequestSessionAttributeNotSaved() {
this.filter.setRequestCache(this.requestCache); this.filter.setRequestCache(this.requestCache);
this.client.get().uri("https://example.com/oauth2/authorization/registration-id").exchange().expectStatus() // @formatter:off
.is3xxRedirection().returnResult(String.class); this.client.get()
verifyZeroInteractions(this.requestCache); .uri("https://example.com/oauth2/authorization/registration-id")
.exchange()
.expectStatus().is3xxRedirection()
.returnResult(String.class);
// @formatter:on
verifyNoInteractions(this.requestCache);
} }
} }

View File

@ -58,18 +58,30 @@ public class ServerOAuth2AuthorizationCodeAuthenticationTokenConverterTests {
private String clientRegistrationId = "github"; private String clientRegistrationId = "github";
// @formatter:off
private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId) private ClientRegistration clientRegistration = ClientRegistration.withRegistrationId(this.clientRegistrationId)
.redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}") .redirectUri("{baseUrl}/{action}/oauth2/code/{registrationId}")
.clientAuthenticationMethod(ClientAuthenticationMethod.BASIC) .clientAuthenticationMethod(ClientAuthenticationMethod.BASIC)
.authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE).scope("read:user") .authorizationGrantType(AuthorizationGrantType.AUTHORIZATION_CODE)
.scope("read:user")
.authorizationUri("https://github.com/login/oauth/authorize") .authorizationUri("https://github.com/login/oauth/authorize")
.tokenUri("https://github.com/login/oauth/access_token").userInfoUri("https://api.github.com/user") .tokenUri("https://github.com/login/oauth/access_token")
.userNameAttributeName("id").clientName("GitHub").clientId("clientId").clientSecret("clientSecret").build(); .userInfoUri("https://api.github.com/user")
.userNameAttributeName("id")
.clientName("GitHub")
.clientId("clientId")
.clientSecret("clientSecret")
.build();
// @formatter:on
// @formatter:off
private OAuth2AuthorizationRequest.Builder authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() private OAuth2AuthorizationRequest.Builder authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize").clientId("client-id") .authorizationUri("https://example.com/oauth2/authorize")
.redirectUri("http://localhost/client-1").state("state") .clientId("client-id")
.redirectUri("http://localhost/client-1")
.state("state")
.attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId)); .attributes(Collections.singletonMap(OAuth2ParameterNames.REGISTRATION_ID, this.clientRegistrationId));
// @formatter:on
private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/"); private final MockServerHttpRequest.BaseBuilder<?> request = MockServerHttpRequest.get("/");

View File

@ -51,9 +51,14 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
private WebSessionOAuth2ServerAuthorizationRequestRepository repository = new WebSessionOAuth2ServerAuthorizationRequestRepository(); private WebSessionOAuth2ServerAuthorizationRequestRepository repository = new WebSessionOAuth2ServerAuthorizationRequestRepository();
// @formatter:off
private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode() private OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize").clientId("client-id") .authorizationUri("https://example.com/oauth2/authorize")
.redirectUri("http://localhost/client-1").state("state").build(); .clientId("client-id")
.redirectUri("http://localhost/client-1")
.state("state")
.build();
// @formatter:on
private ServerWebExchange exchange = MockServerWebExchange private ServerWebExchange exchange = MockServerWebExchange
.from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state")); .from(MockServerHttpRequest.get("/").queryParam(OAuth2ParameterNames.STATE, "state"));
@ -66,55 +71,80 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
@Test @Test
public void loadAuthorizationRequestWhenNoSessionThenEmpty() { public void loadAuthorizationRequestWhenNoSessionThenEmpty() {
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)).verifyComplete(); // @formatter:off
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
// @formatter:on
assertSessionStartedIs(false); assertSessionStartedIs(false);
} }
@Test @Test
public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() { public void loadAuthorizationRequestWhenSessionAndNoRequestThenEmpty() {
Mono<OAuth2AuthorizationRequest> setAttrThenLoad = this.exchange.getSession().map(WebSession::getAttributes) // @formatter:off
Mono<OAuth2AuthorizationRequest> setAttrThenLoad = this.exchange.getSession()
.map(WebSession::getAttributes)
.doOnNext((attrs) -> attrs.put("foo", "bar")) .doOnNext((attrs) -> attrs.put("foo", "bar"))
.then(this.repository.loadAuthorizationRequest(this.exchange)); .then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(setAttrThenLoad).verifyComplete(); StepVerifier.create(setAttrThenLoad)
.verifyComplete();
// @formatter:on
} }
@Test @Test
public void loadAuthorizationRequestWhenNoStateParamThenEmpty() { public void loadAuthorizationRequestWhenNoStateParamThenEmpty() {
this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/")); this.exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/"));
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository
.saveAuthorizationRequest(this.authorizationRequest, this.exchange) .saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.loadAuthorizationRequest(this.exchange)); .then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndLoad).verifyComplete(); StepVerifier.create(saveAndLoad)
.verifyComplete();
// @formatter:on
} }
@Test @Test
public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() { public void loadAuthorizationRequestWhenSavedThenAuthorizationRequest() {
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository Mono<OAuth2AuthorizationRequest> saveAndLoad = this.repository
.saveAuthorizationRequest(this.authorizationRequest, this.exchange) .saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.loadAuthorizationRequest(this.exchange)); .then(this.repository.loadAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndLoad).expectNext(this.authorizationRequest).verifyComplete(); StepVerifier.create(saveAndLoad)
.expectNext(this.authorizationRequest)
.verifyComplete();
// @formatter:on
} }
@Test @Test
public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() { public void loadAuthorizationRequestWhenMultipleSavedThenAuthorizationRequest() {
String oldState = "state0"; String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build(); .queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize").clientId("client-id") .authorizationUri("https://example.com/oauth2/authorize")
.redirectUri("http://localhost/client-1").state(oldState).build(); .clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
WebSessionManager sessionManager = (e) -> this.exchange.getSession(); WebSessionManager sessionManager = (e) -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository Mono<OAuth2AuthorizationRequest> saveAndSaveAndLoad = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.loadAuthorizationRequest(oldExchange)); .then(this.repository.loadAuthorizationRequest(oldExchange));
StepVerifier.create(saveAndSaveAndLoad).expectNext(oldAuthorizationRequest).verifyComplete(); StepVerifier.create(saveAndSaveAndLoad)
.expectNext(oldAuthorizationRequest)
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)) StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.expectNext(this.authorizationRequest).verifyComplete(); .expectNext(this.authorizationRequest)
.verifyComplete();
// @formatter:on
} }
@Test @Test
@ -147,58 +177,89 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
@Test @Test
public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() { public void removeAuthorizationRequestWhenPresentThenFoundAndRemoved() {
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository
.saveAuthorizationRequest(this.authorizationRequest, this.exchange) .saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.removeAuthorizationRequest(this.exchange)); .then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndRemove).expectNext(this.authorizationRequest).verifyComplete(); StepVerifier.create(saveAndRemove)
StepVerifier.create(this.exchange.getSession().map(WebSession::getAttributes).map(Map::isEmpty)) .expectNext(this.authorizationRequest)
.verifyComplete();
StepVerifier.create(this.exchange
.getSession()
.map(WebSession::getAttributes)
.map(Map::isEmpty)
)
.expectNext(true).verifyComplete(); .expectNext(true).verifyComplete();
// @formatter:on
} }
// gh-5599 // gh-5599
@Test @Test
public void removeAuthorizationRequestWhenStateMissingThenNoErrors() { public void removeAuthorizationRequestWhenStateMissingThenNoErrors() {
// @formatter:off
MockServerHttpRequest otherState = MockServerHttpRequest.get("/") MockServerHttpRequest otherState = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, "other").build(); .queryParam(OAuth2ParameterNames.STATE, "other")
ServerWebExchange otherStateExchange = this.exchange.mutate().request(otherState).build(); .build();
ServerWebExchange otherStateExchange = this.exchange.mutate()
.request(otherState)
.build();
Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository Mono<OAuth2AuthorizationRequest> saveAndRemove = this.repository
.saveAuthorizationRequest(this.authorizationRequest, this.exchange) .saveAuthorizationRequest(this.authorizationRequest, this.exchange)
.then(this.repository.removeAuthorizationRequest(otherStateExchange)); .then(this.repository.removeAuthorizationRequest(otherStateExchange));
StepVerifier.create(saveAndRemove).verifyComplete(); StepVerifier.create(saveAndRemove)
.verifyComplete();
// @formatter:on
} }
@Test @Test
public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() { public void removeAuthorizationRequestWhenMultipleThenOnlyOneRemoved() {
String oldState = "state0"; String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build(); .queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize").clientId("client-id") .authorizationUri("https://example.com/oauth2/authorize")
.redirectUri("http://localhost/client-1").state(oldState).build(); .clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
WebSessionManager sessionManager = (e) -> this.exchange.getSession(); WebSessionManager sessionManager = (e) -> this.exchange.getSession();
this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(), this.exchange = new DefaultServerWebExchange(this.exchange.getRequest(), new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange)); .then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest).verifyComplete(); StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)).verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange)).expectNext(oldAuthorizationRequest)
.verifyComplete(); .verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(oldExchange))
.expectNext(oldAuthorizationRequest)
.verifyComplete();
// @formatter:on
} }
// gh-7327 // gh-7327
@Test @Test
public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() { public void removeAuthorizationRequestWhenMultipleThenRemovedAndSessionAttributeUpdated() {
String oldState = "state0"; String oldState = "state0";
// @formatter:off
MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/") MockServerHttpRequest oldRequest = MockServerHttpRequest.get("/")
.queryParam(OAuth2ParameterNames.STATE, oldState).build(); .queryParam(OAuth2ParameterNames.STATE, oldState)
.build();
OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode() OAuth2AuthorizationRequest oldAuthorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.authorizationUri("https://example.com/oauth2/authorize").clientId("client-id") .authorizationUri("https://example.com/oauth2/authorize")
.redirectUri("http://localhost/client-1").state(oldState).build(); .clientId("client-id")
.redirectUri("http://localhost/client-1")
.state(oldState)
.build();
// @formatter:on
Map<String, Object> sessionAttrs = spy(new HashMap<>()); Map<String, Object> sessionAttrs = spy(new HashMap<>());
WebSession session = mock(WebSession.class); WebSession session = mock(WebSession.class);
given(session.getAttributes()).willReturn(sessionAttrs); given(session.getAttributes()).willReturn(sessionAttrs);
@ -207,18 +268,27 @@ public class WebSessionOAuth2ServerAuthorizationRequestRepositoryTests {
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(), ServerWebExchange oldExchange = new DefaultServerWebExchange(oldRequest, new MockServerHttpResponse(),
sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver()); sessionManager, ServerCodecConfigurer.create(), new AcceptHeaderLocaleContextResolver());
// @formatter:off
Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository Mono<OAuth2AuthorizationRequest> saveAndSaveAndRemove = this.repository
.saveAuthorizationRequest(oldAuthorizationRequest, oldExchange) .saveAuthorizationRequest(oldAuthorizationRequest, oldExchange)
.then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange)) .then(this.repository.saveAuthorizationRequest(this.authorizationRequest, this.exchange))
.then(this.repository.removeAuthorizationRequest(this.exchange)); .then(this.repository.removeAuthorizationRequest(this.exchange));
StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest).verifyComplete(); StepVerifier.create(saveAndSaveAndRemove).expectNext(this.authorizationRequest)
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange)).verifyComplete(); .verifyComplete();
StepVerifier.create(this.repository.loadAuthorizationRequest(this.exchange))
.verifyComplete();
// @formatter:on
verify(sessionAttrs, times(3)).put(any(), any()); verify(sessionAttrs, times(3)).put(any(), any());
} }
private void assertSessionStartedIs(boolean expected) { private void assertSessionStartedIs(boolean expected) {
Mono<Boolean> isStarted = this.exchange.getSession().map(WebSession::isStarted); // @formatter:off
StepVerifier.create(isStarted).expectNext(expected).verifyComplete(); Mono<Boolean> isStarted = this.exchange.getSession()
.map(WebSession::isStarted);
StepVerifier.create(isStarted)
.expectNext(expected)
.verifyComplete();
// @formatter:on
} }
} }

View File

@ -90,12 +90,19 @@ public class OAuth2LoginAuthenticationWebFilterTests {
DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"), DefaultOAuth2User user = new DefaultOAuth2User(AuthorityUtils.createAuthorityList("ROLE_USER"),
Collections.singletonMap("user", "rob"), "user"); Collections.singletonMap("user", "rob"), "user");
ClientRegistration clientRegistration = this.registration.build(); ClientRegistration clientRegistration = this.registration.build();
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode().state("state") // @formatter:off
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest.authorizationCode()
.state("state")
.clientId(clientRegistration.getClientId()) .clientId(clientRegistration.getClientId())
.authorizationUri(clientRegistration.getProviderDetails().getAuthorizationUri()) .authorizationUri(clientRegistration.getProviderDetails()
.redirectUri(clientRegistration.getRedirectUri()).scopes(clientRegistration.getScopes()).build(); .getAuthorizationUri())
.redirectUri(clientRegistration.getRedirectUri())
.scopes(clientRegistration.getScopes())
.build();
OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr OAuth2AuthorizationResponse authorizationResponse = this.authorizationResponseBldr
.redirectUri(clientRegistration.getRedirectUri()).build(); .redirectUri(clientRegistration.getRedirectUri())
.build();
// @formatter:on
OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest, OAuth2AuthorizationExchange authorizationExchange = new OAuth2AuthorizationExchange(authorizationRequest,
authorizationResponse); authorizationResponse);
return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user, return new OAuth2LoginAuthenticationToken(clientRegistration, authorizationExchange, user,