OAuth2AuthorizedClientResolver

Extract out a private API for shared code between the argument resolver
and WebClient support. This makes it easier to make changes in both
locations. Later we will extract this out so it is not a copy/paste
effort.

Issue: gh-4921
This commit is contained in:
Rob Winch 2018-09-05 20:38:12 -05:00
parent 23726abb1e
commit 438d2911fb
7 changed files with 457 additions and 227 deletions

View File

@ -21,6 +21,7 @@ import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.ImportSelector;
import org.springframework.core.type.AnnotationMetadata;
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
@ -53,17 +54,25 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
@Configuration
static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
@Override
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
if (this.authorizedClientRepository != null) {
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(getAuthorizedClientRepository()));
if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) {
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository()));
}
}
@Autowired(required = false)
public void setClientRegistrationRepository(
ReactiveClientRegistrationRepository clientRegistrationRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
}
@Autowired(required = false)
public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.authorizedClientRepository = authorizedClientRepository;

View File

@ -0,0 +1,185 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Optional;
/**
* @author Rob Winch
* @since 5.1
*/
class OAuth2AuthorizedClientResolver {
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER"));
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient();
private boolean defaultOAuth2AuthorizedClient;
public OAuth2AuthorizedClientResolver(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
/**
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
}
/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
}
Mono<Request> createDefaultedRequest(String clientRegistrationId,
Authentication authentication, ServerWebExchange exchange) {
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
.switchIfEmpty(currentAuthentication());
Mono<String> defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId)
.switchIfEmpty(clientRegistrationId(defaultedAuthentication));
Mono<Optional<ServerWebExchange>> defaultedExchange = Mono.justOrEmpty(exchange)
.switchIfEmpty(currentServerWebExchange()).map(Optional::of)
.defaultIfEmpty(Optional.empty());
return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
.map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
}
Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
String clientRegistrationId = request.getClientRegistrationId();
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
}
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse));
}
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, authentication.getName(), tokenResponse.getAccessToken());
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange)
.thenReturn(authorizedClient);
}
/**
* Attempts to load the client registration id from the current {@link Authentication}
* @return
*/
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
return authentication
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
}
private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
}
private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
static class Request {
private final String clientRegistrationId;
private final Authentication authentication;
private final ServerWebExchange exchange;
public Request(String clientRegistrationId, Authentication authentication,
ServerWebExchange exchange) {
this.clientRegistrationId = clientRegistrationId;
this.authentication = authentication;
this.exchange = exchange;
}
public String getClientRegistrationId() {
return this.clientRegistrationId;
}
public Authentication getAuthentication() {
return this.authentication;
}
public ServerWebExchange getExchange() {
return this.exchange;
}
}
}

View File

@ -16,28 +16,21 @@
package org.springframework.security.oauth2.client.web.reactive.function.client;
import com.sun.security.ntlm.Server;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
@ -51,9 +44,7 @@ import java.net.URI;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
@ -88,20 +79,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
private boolean defaultOAuth2AuthorizedClient;
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
new WebClientReactiveClientCredentialsTokenResponseClient();
private ReactiveClientRegistrationRepository clientRegistrationRepository;
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
}
/**
@ -142,6 +126,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
}
private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) {
return (OAuth2AuthorizedClient) request.attributes().get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
}
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
@ -166,6 +153,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
}
private static ServerWebExchange serverWebExchange(ClientRequest request) {
return (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
}
/**
* Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
* be used to look up the {@link OAuth2AuthorizedClient}.
@ -178,6 +169,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
}
private static String clientRegistrationId(ClientRequest request) {
OAuth2AuthorizedClient authorizedClient = oauth2AuthorizedClient(request);
if (authorizedClient != null) {
return authorizedClient.getClientRegistration().getRegistrationId();
}
return (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
}
/**
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
@ -186,7 +185,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
* Default is false.
*/
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient);
}
/**
@ -196,8 +195,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
*/
public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
}
/**
@ -212,128 +210,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
@Override
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
return authorizedClient(request)
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
return authorizedClient(request, next)
.map(authorizedClient -> bearer(request, authorizedClient))
.flatMap(next::exchange)
.switchIfEmpty(next.exchange(request));
}
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
return Mono.justOrEmpty(exchange)
.switchIfEmpty(serverWebExchange());
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next) {
OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
return Mono.justOrEmpty(authorizedClientFromAttrs)
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request)))
.flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient));
}
private Mono<ServerWebExchange> serverWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(ClientRequest request) {
return createRequest(request)
.flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r));
}
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
.map(OAuth2AuthorizedClient.class::cast);
return Mono.justOrEmpty(attribute)
.switchIfEmpty(findAuthorizedClientByRegistrationId(request));
private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest request) {
String clientRegistrationId = clientRegistrationId(request);
Authentication authentication = null;
ServerWebExchange exchange = serverWebExchange(request);
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange);
}
private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
if (this.authorizedClientRepository == null) {
return Mono.empty();
}
return currentAuthentication()
.flatMap(principal -> clientRegistrationId(request, principal)
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
);
}
private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
.cast(String.class)
.switchIfEmpty(clientRegistrationId(authentication));
}
private Mono<String> clientRegistrationId(Authentication authentication) {
return Mono.justOrEmpty(authentication)
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
}
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
ServerWebExchange exchange, Authentication principal) {
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
}
private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
}
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
return currentAuthentication()
.flatMap(principal -> {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, (principal != null ?
principal.getName() :
"anonymousUser"),
tokenResponse.getAccessToken());
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
.thenReturn(authorizedClient);
});
}
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
if (shouldRefresh(authorizedClient)) {
return serverWebExchange(request)
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
return createRequest(request)
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
}
return Mono.just(authorizedClient);
}
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
ServerWebExchange exchange = r.getExchange();
Authentication authentication = r.getAuthentication();
ClientRegistration clientRegistration = authorizedClient
.getClientRegistration();
String tokenUri = clientRegistration
.getProviderDetails().getTokenUri();
ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
.build();
return next.exchange(request)
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
return next.exchange(refreshRequest)
.flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
.flatMap(result -> currentAuthentication()
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
.flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
.thenReturn(result));
}
private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
}
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
if (this.authorizedClientRepository == null) {
return false;
@ -361,52 +290,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
.with("refresh_token", refreshToken);
}
private static class PrincipalNameAuthentication implements Authentication {
private final String username;
private PrincipalNameAuthentication(String username) {
this.username = username;
}
@Override
public Collection<? extends GrantedAuthority> getAuthorities() {
throw unsupported();
}
@Override
public Object getCredentials() {
throw unsupported();
}
@Override
public Object getDetails() {
throw unsupported();
}
@Override
public Object getPrincipal() {
throw unsupported();
}
@Override
public boolean isAuthenticated() {
throw unsupported();
}
@Override
public void setAuthenticated(boolean isAuthenticated)
throws IllegalArgumentException {
throw unsupported();
}
@Override
public String getName() {
return this.username;
}
private UnsupportedOperationException unsupported() {
return new UnsupportedOperationException("Not Supported");
}
}
}

View File

@ -18,15 +18,9 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an
import org.springframework.core.MethodParameter;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@ -56,16 +50,18 @@ import reactor.core.publisher.Mono;
* @see RegisteredOAuth2AuthorizedClient
*/
public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
/**
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
*
* @param authorizedClientRepository the authorized client repository
*/
public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.authorizedClientRepository = authorizedClientRepository;
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true);
}
@Override
@ -80,41 +76,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
Mono<String> clientRegistrationId = Mono.justOrEmpty(authorizedClientAnnotation.registrationId())
.filter(id -> !StringUtils.isEmpty(id))
.switchIfEmpty(clientRegistrationId())
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException(
"Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."))));
String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ?
authorizedClientAnnotation.registrationId() : null;
Mono<Authentication> principal = ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(new AnonymousAuthenticationToken("key", "anonymous",
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")));
Mono<OAuth2AuthorizedClient> authorizedClient = Mono
.zip(clientRegistrationId, principal).switchIfEmpty(
clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException(
"Unable to resolve the Authorized Client with registration identifier \""
+ id
+ "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured."))))
.flatMap(zipped -> {
String registrationId = zipped.getT1();
Authentication authentication = zipped.getT2();
return this.authorizedClientRepository
.loadAuthorizedClient(registrationId, authentication, exchange).switchIfEmpty(Mono.defer(() -> Mono
.error(new ClientAuthorizationRequiredException(
registrationId))));
}).cast(OAuth2AuthorizedClient.class);
return authorizedClient.cast(Object.class);
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange)
.flatMap(this.authorizedClientResolver::loadAuthorizedClient);
});
}
private Mono<String> clientRegistrationId() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.filter(authentication -> authentication instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
}
}

View File

@ -0,0 +1,186 @@
/*
* Copyright 2002-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.security.oauth2.client.web.reactive.result.method.annotation;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.authority.AuthorityUtils;
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
import org.springframework.util.Assert;
import org.springframework.web.server.ServerWebExchange;
import reactor.core.publisher.Mono;
import java.util.Optional;
/**
* @author Rob Winch
* @since 5.1
*/
class OAuth2AuthorizedClientResolver {
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
AuthorityUtils.createAuthorityList("ROLE_USER"));
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient();
private boolean defaultOAuth2AuthorizedClient;
public OAuth2AuthorizedClientResolver(
ReactiveClientRegistrationRepository clientRegistrationRepository,
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
this.clientRegistrationRepository = clientRegistrationRepository;
this.authorizedClientRepository = authorizedClientRepository;
}
/**
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
* resolved from the current Authentication.
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
* Default is false.
*/
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
}
/**
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
* client_credentials grant.
* @param clientCredentialsTokenResponseClient the client to use
*/
public void setClientCredentialsTokenResponseClient(
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
}
Mono<Request> createDefaultedRequest(String clientRegistrationId,
Authentication authentication, ServerWebExchange exchange) {
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
.switchIfEmpty(currentAuthentication());
Mono<String> defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId)
.switchIfEmpty(clientRegistrationId(defaultedAuthentication))
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one")));
Mono<Optional<ServerWebExchange>> defaultedExchange = Mono.justOrEmpty(exchange)
.switchIfEmpty(currentServerWebExchange()).map(Optional::of)
.defaultIfEmpty(Optional.empty());
return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
.map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
}
Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
String clientRegistrationId = request.getClientRegistrationId();
Authentication authentication = request.getAuthentication();
ServerWebExchange exchange = request.getExchange();
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
}
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
.flatMap(clientRegistration -> {
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
return clientCredentials(clientRegistration, authentication, exchange);
}
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
});
}
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse));
}
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) {
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
clientRegistration, authentication.getName(), tokenResponse.getAccessToken());
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange)
.thenReturn(authorizedClient);
}
/**
* Attempts to load the client registration id from the current {@link Authentication}
* @return
*/
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
return authentication
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
.cast(OAuth2AuthenticationToken.class)
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
}
private Mono<Authentication> currentAuthentication() {
return ReactiveSecurityContextHolder.getContext()
.map(SecurityContext::getAuthentication)
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
}
private Mono<ServerWebExchange> currentServerWebExchange() {
return Mono.subscriberContext()
.filter(c -> c.hasKey(ServerWebExchange.class))
.map(c -> c.get(ServerWebExchange.class));
}
static class Request {
private final String clientRegistrationId;
private final Authentication authentication;
private final ServerWebExchange exchange;
public Request(String clientRegistrationId, Authentication authentication,
ServerWebExchange exchange) {
this.clientRegistrationId = clientRegistrationId;
this.authentication = authentication;
this.exchange = exchange;
}
public String getClientRegistrationId() {
return this.clientRegistrationId;
}
public Authentication getAuthentication() {
return this.authentication;
}
public ServerWebExchange getExchange() {
return this.exchange;
}
}
}

View File

@ -16,6 +16,7 @@
package org.springframework.security.oauth2.client.web.reactive.function.client;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
@ -88,7 +89,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Mock
private ServerWebExchange serverWebExchange;
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
private ServerOAuth2AuthorizedClientExchangeFilterFunction function;
private MockExchangeFunction exchange = new MockExchangeFunction();
@ -100,6 +101,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
Instant.now(),
Instant.now().plus(Duration.ofDays(1)));
@Before
public void setup() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
}
@Test
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
@ -155,7 +161,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -204,7 +209,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
this.accessToken.getTokenValue(),
issuedAt,
accessTokenExpiresAt);
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
@ -236,8 +240,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
@ -258,8 +260,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenNotExpiredThenShouldRefreshFalse() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
@ -281,8 +281,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);
@ -306,7 +304,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
this.function.setDefaultOAuth2AuthorizedClient(true);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
@ -337,8 +334,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
.build();
@ -359,8 +354,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
@Test
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
"principalName", this.accessToken, refreshToken);

View File

@ -29,9 +29,10 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
import org.springframework.util.ReflectionUtils;
import reactor.core.publisher.Hooks;
import reactor.core.publisher.Mono;
import reactor.util.context.Context;
@ -50,6 +51,8 @@ import static org.mockito.Mockito.when;
*/
@RunWith(MockitoJUnitRunner.class)
public class OAuth2AuthorizedClientArgumentResolverTests {
@Mock
private ReactiveClientRegistrationRepository clientRegistrationRepository;
@Mock
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
@ -59,15 +62,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
@Before
public void setUp() {
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository);
this.authorizedClient = mock(OAuth2AuthorizedClient.class);
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient));
Hooks.onOperatorDebug();
}
@Test
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null))
.isInstanceOf(IllegalArgumentException.class);
}
@ -94,11 +96,13 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
assertThatThrownBy(() -> resolveArgument(methodParameter))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
.hasMessage("The clientRegistrationId could not be resolved. Please provide one");
}
@Test
public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
this.authentication = mock(OAuth2AuthenticationToken.class);
when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
@ -108,18 +112,24 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
@Test
public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() {
this.authentication = null;
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
}
@Test
public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
}
@Test
public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
TestClientRegistrations.clientRegistration().build()));
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty());
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
assertThatThrownBy(() -> resolveArgument(methodParameter))