OAuth2AuthorizedClientResolver
Extract out a private API for shared code between the argument resolver and WebClient support. This makes it easier to make changes in both locations. Later we will extract this out so it is not a copy/paste effort. Issue: gh-4921
This commit is contained in:
parent
23726abb1e
commit
438d2911fb
|
@ -21,6 +21,7 @@ import org.springframework.context.annotation.Configuration;
|
||||||
import org.springframework.context.annotation.ImportSelector;
|
import org.springframework.context.annotation.ImportSelector;
|
||||||
import org.springframework.core.type.AnnotationMetadata;
|
import org.springframework.core.type.AnnotationMetadata;
|
||||||
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
|
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
|
import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver;
|
||||||
|
@ -53,17 +54,25 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector {
|
||||||
|
|
||||||
@Configuration
|
@Configuration
|
||||||
static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
|
static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer {
|
||||||
|
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||||
|
|
||||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||||
|
|
||||||
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
|
private ReactiveOAuth2AuthorizedClientService authorizedClientService;
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
|
public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) {
|
||||||
if (this.authorizedClientRepository != null) {
|
if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) {
|
||||||
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(getAuthorizedClientRepository()));
|
configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Autowired(required = false)
|
||||||
|
public void setClientRegistrationRepository(
|
||||||
|
ReactiveClientRegistrationRepository clientRegistrationRepository) {
|
||||||
|
this.clientRegistrationRepository = clientRegistrationRepository;
|
||||||
|
}
|
||||||
|
|
||||||
@Autowired(required = false)
|
@Autowired(required = false)
|
||||||
public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||||
this.authorizedClientRepository = authorizedClientRepository;
|
this.authorizedClientRepository = authorizedClientRepository;
|
||||||
|
|
|
@ -0,0 +1,185 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2018 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||||
|
|
||||||
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||||
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.authority.AuthorityUtils;
|
||||||
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||||
|
import org.springframework.security.core.context.SecurityContext;
|
||||||
|
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
||||||
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||||
|
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Rob Winch
|
||||||
|
* @since 5.1
|
||||||
|
*/
|
||||||
|
class OAuth2AuthorizedClientResolver {
|
||||||
|
|
||||||
|
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
|
||||||
|
AuthorityUtils.createAuthorityList("ROLE_USER"));
|
||||||
|
|
||||||
|
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||||
|
|
||||||
|
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||||
|
|
||||||
|
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient();
|
||||||
|
|
||||||
|
private boolean defaultOAuth2AuthorizedClient;
|
||||||
|
|
||||||
|
public OAuth2AuthorizedClientResolver(
|
||||||
|
ReactiveClientRegistrationRepository clientRegistrationRepository,
|
||||||
|
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||||
|
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
|
||||||
|
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
|
||||||
|
this.clientRegistrationRepository = clientRegistrationRepository;
|
||||||
|
this.authorizedClientRepository = authorizedClientRepository;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
||||||
|
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
||||||
|
* resolved from the current Authentication.
|
||||||
|
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
|
||||||
|
* Default is false.
|
||||||
|
*/
|
||||||
|
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
||||||
|
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
||||||
|
* client_credentials grant.
|
||||||
|
* @param clientCredentialsTokenResponseClient the client to use
|
||||||
|
*/
|
||||||
|
public void setClientCredentialsTokenResponseClient(
|
||||||
|
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
||||||
|
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
|
||||||
|
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mono<Request> createDefaultedRequest(String clientRegistrationId,
|
||||||
|
Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
|
||||||
|
.switchIfEmpty(currentAuthentication());
|
||||||
|
|
||||||
|
Mono<String> defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId)
|
||||||
|
.switchIfEmpty(clientRegistrationId(defaultedAuthentication));
|
||||||
|
|
||||||
|
Mono<Optional<ServerWebExchange>> defaultedExchange = Mono.justOrEmpty(exchange)
|
||||||
|
.switchIfEmpty(currentServerWebExchange()).map(Optional::of)
|
||||||
|
.defaultIfEmpty(Optional.empty());
|
||||||
|
|
||||||
|
return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
|
||||||
|
.map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
|
||||||
|
String clientRegistrationId = request.getClientRegistrationId();
|
||||||
|
Authentication authentication = request.getAuthentication();
|
||||||
|
ServerWebExchange exchange = request.getExchange();
|
||||||
|
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
|
||||||
|
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||||
|
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
|
||||||
|
.flatMap(clientRegistration -> {
|
||||||
|
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
|
||||||
|
return clientCredentials(clientRegistration, authentication, exchange);
|
||||||
|
}
|
||||||
|
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
|
||||||
|
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
|
||||||
|
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
|
||||||
|
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) {
|
||||||
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||||
|
clientRegistration, authentication.getName(), tokenResponse.getAccessToken());
|
||||||
|
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange)
|
||||||
|
.thenReturn(authorizedClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempts to load the client registration id from the current {@link Authentication}
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
|
||||||
|
return authentication
|
||||||
|
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
|
||||||
|
.cast(OAuth2AuthenticationToken.class)
|
||||||
|
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<Authentication> currentAuthentication() {
|
||||||
|
return ReactiveSecurityContextHolder.getContext()
|
||||||
|
.map(SecurityContext::getAuthentication)
|
||||||
|
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<ServerWebExchange> currentServerWebExchange() {
|
||||||
|
return Mono.subscriberContext()
|
||||||
|
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||||
|
.map(c -> c.get(ServerWebExchange.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
static class Request {
|
||||||
|
private final String clientRegistrationId;
|
||||||
|
private final Authentication authentication;
|
||||||
|
private final ServerWebExchange exchange;
|
||||||
|
|
||||||
|
public Request(String clientRegistrationId, Authentication authentication,
|
||||||
|
ServerWebExchange exchange) {
|
||||||
|
this.clientRegistrationId = clientRegistrationId;
|
||||||
|
this.authentication = authentication;
|
||||||
|
this.exchange = exchange;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getClientRegistrationId() {
|
||||||
|
return this.clientRegistrationId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Authentication getAuthentication() {
|
||||||
|
return this.authentication;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ServerWebExchange getExchange() {
|
||||||
|
return this.exchange;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,28 +16,21 @@
|
||||||
|
|
||||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||||
|
|
||||||
import com.sun.security.ntlm.Server;
|
|
||||||
import org.springframework.http.HttpHeaders;
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpMethod;
|
import org.springframework.http.HttpMethod;
|
||||||
import org.springframework.http.MediaType;
|
import org.springframework.http.MediaType;
|
||||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||||
import org.springframework.security.core.Authentication;
|
import org.springframework.security.core.Authentication;
|
||||||
import org.springframework.security.core.GrantedAuthority;
|
|
||||||
import org.springframework.security.core.authority.AuthorityUtils;
|
import org.springframework.security.core.authority.AuthorityUtils;
|
||||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||||
import org.springframework.security.core.context.SecurityContext;
|
|
||||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
||||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
|
||||||
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||||
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
||||||
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
|
||||||
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
|
||||||
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.web.reactive.function.BodyInserters;
|
import org.springframework.web.reactive.function.BodyInserters;
|
||||||
import org.springframework.web.reactive.function.client.ClientRequest;
|
import org.springframework.web.reactive.function.client.ClientRequest;
|
||||||
|
@ -51,9 +44,7 @@ import java.net.URI;
|
||||||
import java.time.Clock;
|
import java.time.Clock;
|
||||||
import java.time.Duration;
|
import java.time.Duration;
|
||||||
import java.time.Instant;
|
import java.time.Instant;
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
|
||||||
import java.util.function.Consumer;
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
|
import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse;
|
||||||
|
@ -88,20 +79,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
|
|
||||||
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);
|
||||||
|
|
||||||
private boolean defaultOAuth2AuthorizedClient;
|
|
||||||
|
|
||||||
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient =
|
|
||||||
new WebClientReactiveClientCredentialsTokenResponseClient();
|
|
||||||
|
|
||||||
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
|
||||||
|
|
||||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||||
|
|
||||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction() {}
|
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
|
||||||
|
|
||||||
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||||
this.clientRegistrationRepository = clientRegistrationRepository;
|
|
||||||
this.authorizedClientRepository = authorizedClientRepository;
|
this.authorizedClientRepository = authorizedClientRepository;
|
||||||
|
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -142,6 +126,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
|
return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) {
|
||||||
|
return (OAuth2AuthorizedClient) request.attributes().get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
|
* Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for
|
||||||
|
@ -166,6 +153,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
|
return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static ServerWebExchange serverWebExchange(ClientRequest request) {
|
||||||
|
return (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
|
* Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to
|
||||||
* be used to look up the {@link OAuth2AuthorizedClient}.
|
* be used to look up the {@link OAuth2AuthorizedClient}.
|
||||||
|
@ -178,6 +169,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private static String clientRegistrationId(ClientRequest request) {
|
||||||
|
OAuth2AuthorizedClient authorizedClient = oauth2AuthorizedClient(request);
|
||||||
|
if (authorizedClient != null) {
|
||||||
|
return authorizedClient.getClientRegistration().getRegistrationId();
|
||||||
|
}
|
||||||
|
return (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
||||||
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
||||||
|
@ -186,7 +185,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
* Default is false.
|
* Default is false.
|
||||||
*/
|
*/
|
||||||
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
||||||
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
|
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -196,8 +195,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
*/
|
*/
|
||||||
public void setClientCredentialsTokenResponseClient(
|
public void setClientCredentialsTokenResponseClient(
|
||||||
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
||||||
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
|
this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient);
|
||||||
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -212,128 +210,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
public Mono<ClientResponse> filter(ClientRequest request, ExchangeFunction next) {
|
||||||
return authorizedClient(request)
|
return authorizedClient(request, next)
|
||||||
.flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request))
|
|
||||||
.map(authorizedClient -> bearer(request, authorizedClient))
|
.map(authorizedClient -> bearer(request, authorizedClient))
|
||||||
.flatMap(next::exchange)
|
.flatMap(next::exchange)
|
||||||
.switchIfEmpty(next.exchange(request));
|
.switchIfEmpty(next.exchange(request));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<ServerWebExchange> serverWebExchange(ClientRequest request) {
|
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request, ExchangeFunction next) {
|
||||||
ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME);
|
OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request);
|
||||||
return Mono.justOrEmpty(exchange)
|
return Mono.justOrEmpty(authorizedClientFromAttrs)
|
||||||
.switchIfEmpty(serverWebExchange());
|
.switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request)))
|
||||||
|
.flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<ServerWebExchange> serverWebExchange() {
|
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(ClientRequest request) {
|
||||||
return Mono.subscriberContext()
|
return createRequest(request)
|
||||||
.filter(c -> c.hasKey(ServerWebExchange.class))
|
.flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r));
|
||||||
.map(c -> c.get(ServerWebExchange.class));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest request) {
|
private Mono<OAuth2AuthorizedClientResolver.Request> createRequest(ClientRequest request) {
|
||||||
Optional<OAuth2AuthorizedClient> attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
|
String clientRegistrationId = clientRegistrationId(request);
|
||||||
.map(OAuth2AuthorizedClient.class::cast);
|
Authentication authentication = null;
|
||||||
return Mono.justOrEmpty(attribute)
|
ServerWebExchange exchange = serverWebExchange(request);
|
||||||
.switchIfEmpty(findAuthorizedClientByRegistrationId(request));
|
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> findAuthorizedClientByRegistrationId(ClientRequest request) {
|
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) {
|
||||||
if (this.authorizedClientRepository == null) {
|
|
||||||
return Mono.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
return currentAuthentication()
|
|
||||||
.flatMap(principal -> clientRegistrationId(request, principal)
|
|
||||||
.flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal)))
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<String> clientRegistrationId(ClientRequest request, Authentication authentication) {
|
|
||||||
return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME))
|
|
||||||
.cast(String.class)
|
|
||||||
.switchIfEmpty(clientRegistrationId(authentication));
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<String> clientRegistrationId(Authentication authentication) {
|
|
||||||
return Mono.justOrEmpty(authentication)
|
|
||||||
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
|
|
||||||
.cast(OAuth2AuthenticationToken.class)
|
|
||||||
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> loadAuthorizedClient(String clientRegistrationId,
|
|
||||||
ServerWebExchange exchange, Authentication principal) {
|
|
||||||
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange)
|
|
||||||
.switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange));
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) {
|
|
||||||
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
|
||||||
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
|
|
||||||
.flatMap(clientRegistration -> {
|
|
||||||
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
|
|
||||||
return clientCredentials(clientRegistration, exchange);
|
|
||||||
}
|
|
||||||
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
|
|
||||||
ClientRegistration clientRegistration, ServerWebExchange exchange) {
|
|
||||||
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
|
|
||||||
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
|
|
||||||
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange));
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) {
|
|
||||||
return currentAuthentication()
|
|
||||||
.flatMap(principal -> {
|
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
|
||||||
clientRegistration, (principal != null ?
|
|
||||||
principal.getName() :
|
|
||||||
"anonymousUser"),
|
|
||||||
tokenResponse.getAccessToken());
|
|
||||||
|
|
||||||
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null)
|
|
||||||
.thenReturn(authorizedClient);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) {
|
|
||||||
if (shouldRefresh(authorizedClient)) {
|
if (shouldRefresh(authorizedClient)) {
|
||||||
return serverWebExchange(request)
|
return createRequest(request)
|
||||||
.flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange));
|
.flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r));
|
||||||
}
|
}
|
||||||
return Mono.just(authorizedClient);
|
return Mono.just(authorizedClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
|
private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ExchangeFunction next,
|
||||||
OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) {
|
OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) {
|
||||||
|
ServerWebExchange exchange = r.getExchange();
|
||||||
|
Authentication authentication = r.getAuthentication();
|
||||||
ClientRegistration clientRegistration = authorizedClient
|
ClientRegistration clientRegistration = authorizedClient
|
||||||
.getClientRegistration();
|
.getClientRegistration();
|
||||||
String tokenUri = clientRegistration
|
String tokenUri = clientRegistration
|
||||||
.getProviderDetails().getTokenUri();
|
.getProviderDetails().getTokenUri();
|
||||||
ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
|
ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri))
|
||||||
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
.header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE)
|
||||||
.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
|
.headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
|
||||||
.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
|
.body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue()))
|
||||||
.build();
|
.build();
|
||||||
return next.exchange(request)
|
return next.exchange(refreshRequest)
|
||||||
.flatMap(response -> response.body(oauth2AccessTokenResponse()))
|
.flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse()))
|
||||||
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
|
.map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken()))
|
||||||
.flatMap(result -> currentAuthentication()
|
.flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange)
|
||||||
.defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName()))
|
|
||||||
.flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange))
|
|
||||||
.thenReturn(result));
|
.thenReturn(result));
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<Authentication> currentAuthentication() {
|
|
||||||
return ReactiveSecurityContextHolder.getContext()
|
|
||||||
.map(SecurityContext::getAuthentication)
|
|
||||||
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
|
private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) {
|
||||||
if (this.authorizedClientRepository == null) {
|
if (this.authorizedClientRepository == null) {
|
||||||
return false;
|
return false;
|
||||||
|
@ -361,52 +290,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements
|
||||||
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
|
.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
|
||||||
.with("refresh_token", refreshToken);
|
.with("refresh_token", refreshToken);
|
||||||
}
|
}
|
||||||
|
|
||||||
private static class PrincipalNameAuthentication implements Authentication {
|
|
||||||
private final String username;
|
|
||||||
|
|
||||||
private PrincipalNameAuthentication(String username) {
|
|
||||||
this.username = username;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Collection<? extends GrantedAuthority> getAuthorities() {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object getCredentials() {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object getDetails() {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object getPrincipal() {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isAuthenticated() {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void setAuthenticated(boolean isAuthenticated)
|
|
||||||
throws IllegalArgumentException {
|
|
||||||
throw unsupported();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getName() {
|
|
||||||
return this.username;
|
|
||||||
}
|
|
||||||
|
|
||||||
private UnsupportedOperationException unsupported() {
|
|
||||||
return new UnsupportedOperationException("Not Supported");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,15 +18,9 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an
|
||||||
|
|
||||||
import org.springframework.core.MethodParameter;
|
import org.springframework.core.MethodParameter;
|
||||||
import org.springframework.core.annotation.AnnotatedElementUtils;
|
import org.springframework.core.annotation.AnnotatedElementUtils;
|
||||||
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
|
||||||
import org.springframework.security.core.Authentication;
|
|
||||||
import org.springframework.security.core.authority.AuthorityUtils;
|
|
||||||
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
|
||||||
import org.springframework.security.core.context.SecurityContext;
|
|
||||||
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
|
||||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||||
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
|
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
|
||||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.util.Assert;
|
import org.springframework.util.Assert;
|
||||||
import org.springframework.util.StringUtils;
|
import org.springframework.util.StringUtils;
|
||||||
|
@ -56,16 +50,18 @@ import reactor.core.publisher.Mono;
|
||||||
* @see RegisteredOAuth2AuthorizedClient
|
* @see RegisteredOAuth2AuthorizedClient
|
||||||
*/
|
*/
|
||||||
public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
|
public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver {
|
||||||
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
|
||||||
|
private final OAuth2AuthorizedClientResolver authorizedClientResolver;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
|
* Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters.
|
||||||
*
|
*
|
||||||
* @param authorizedClientRepository the authorized client repository
|
* @param authorizedClientRepository the authorized client repository
|
||||||
*/
|
*/
|
||||||
public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||||
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
|
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
|
||||||
this.authorizedClientRepository = authorizedClientRepository;
|
this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository);
|
||||||
|
this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
|
@ -80,41 +76,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth
|
||||||
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
|
RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils
|
||||||
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
|
.findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class);
|
||||||
|
|
||||||
Mono<String> clientRegistrationId = Mono.justOrEmpty(authorizedClientAnnotation.registrationId())
|
String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ?
|
||||||
.filter(id -> !StringUtils.isEmpty(id))
|
authorizedClientAnnotation.registrationId() : null;
|
||||||
.switchIfEmpty(clientRegistrationId())
|
|
||||||
.switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException(
|
|
||||||
"Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."))));
|
|
||||||
|
|
||||||
Mono<Authentication> principal = ReactiveSecurityContextHolder.getContext()
|
return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange)
|
||||||
.map(SecurityContext::getAuthentication)
|
.flatMap(this.authorizedClientResolver::loadAuthorizedClient);
|
||||||
.defaultIfEmpty(new AnonymousAuthenticationToken("key", "anonymous",
|
|
||||||
AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS")));
|
|
||||||
|
|
||||||
Mono<OAuth2AuthorizedClient> authorizedClient = Mono
|
|
||||||
.zip(clientRegistrationId, principal).switchIfEmpty(
|
|
||||||
clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException(
|
|
||||||
"Unable to resolve the Authorized Client with registration identifier \""
|
|
||||||
+ id
|
|
||||||
+ "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured."))))
|
|
||||||
.flatMap(zipped -> {
|
|
||||||
String registrationId = zipped.getT1();
|
|
||||||
Authentication authentication = zipped.getT2();
|
|
||||||
return this.authorizedClientRepository
|
|
||||||
.loadAuthorizedClient(registrationId, authentication, exchange).switchIfEmpty(Mono.defer(() -> Mono
|
|
||||||
.error(new ClientAuthorizationRequiredException(
|
|
||||||
registrationId))));
|
|
||||||
}).cast(OAuth2AuthorizedClient.class);
|
|
||||||
|
|
||||||
return authorizedClient.cast(Object.class);
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
private Mono<String> clientRegistrationId() {
|
|
||||||
return ReactiveSecurityContextHolder.getContext()
|
|
||||||
.map(SecurityContext::getAuthentication)
|
|
||||||
.filter(authentication -> authentication instanceof OAuth2AuthenticationToken)
|
|
||||||
.cast(OAuth2AuthenticationToken.class)
|
|
||||||
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,186 @@
|
||||||
|
/*
|
||||||
|
* Copyright 2002-2018 the original author or authors.
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.springframework.security.oauth2.client.web.reactive.result.method.annotation;
|
||||||
|
|
||||||
|
import org.springframework.security.authentication.AnonymousAuthenticationToken;
|
||||||
|
import org.springframework.security.core.Authentication;
|
||||||
|
import org.springframework.security.core.authority.AuthorityUtils;
|
||||||
|
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
|
||||||
|
import org.springframework.security.core.context.SecurityContext;
|
||||||
|
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
|
||||||
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||||
|
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient;
|
||||||
|
import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ClientRegistration;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
|
import org.springframework.security.oauth2.core.AuthorizationGrantType;
|
||||||
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
|
||||||
|
import org.springframework.util.Assert;
|
||||||
|
import org.springframework.web.server.ServerWebExchange;
|
||||||
|
import reactor.core.publisher.Mono;
|
||||||
|
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @author Rob Winch
|
||||||
|
* @since 5.1
|
||||||
|
*/
|
||||||
|
class OAuth2AuthorizedClientResolver {
|
||||||
|
|
||||||
|
private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser",
|
||||||
|
AuthorityUtils.createAuthorityList("ROLE_USER"));
|
||||||
|
|
||||||
|
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||||
|
|
||||||
|
private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||||
|
|
||||||
|
private ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient();
|
||||||
|
|
||||||
|
private boolean defaultOAuth2AuthorizedClient;
|
||||||
|
|
||||||
|
public OAuth2AuthorizedClientResolver(
|
||||||
|
ReactiveClientRegistrationRepository clientRegistrationRepository,
|
||||||
|
ServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
|
||||||
|
Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null");
|
||||||
|
Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null");
|
||||||
|
this.clientRegistrationRepository = clientRegistrationRepository;
|
||||||
|
this.authorizedClientRepository = authorizedClientRepository;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is
|
||||||
|
* recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be
|
||||||
|
* resolved from the current Authentication.
|
||||||
|
* @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false.
|
||||||
|
* Default is false.
|
||||||
|
*/
|
||||||
|
public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) {
|
||||||
|
this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for
|
||||||
|
* client_credentials grant.
|
||||||
|
* @param clientCredentialsTokenResponseClient the client to use
|
||||||
|
*/
|
||||||
|
public void setClientCredentialsTokenResponseClient(
|
||||||
|
ReactiveOAuth2AccessTokenResponseClient<OAuth2ClientCredentialsGrantRequest> clientCredentialsTokenResponseClient) {
|
||||||
|
Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null");
|
||||||
|
this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient;
|
||||||
|
}
|
||||||
|
|
||||||
|
Mono<Request> createDefaultedRequest(String clientRegistrationId,
|
||||||
|
Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
Mono<Authentication> defaultedAuthentication = Mono.justOrEmpty(authentication)
|
||||||
|
.switchIfEmpty(currentAuthentication());
|
||||||
|
|
||||||
|
Mono<String> defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId)
|
||||||
|
.switchIfEmpty(clientRegistrationId(defaultedAuthentication))
|
||||||
|
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one")));
|
||||||
|
|
||||||
|
Mono<Optional<ServerWebExchange>> defaultedExchange = Mono.justOrEmpty(exchange)
|
||||||
|
.switchIfEmpty(currentServerWebExchange()).map(Optional::of)
|
||||||
|
.defaultIfEmpty(Optional.empty());
|
||||||
|
|
||||||
|
return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange)
|
||||||
|
.map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null)));
|
||||||
|
}
|
||||||
|
|
||||||
|
Mono<OAuth2AuthorizedClient> loadAuthorizedClient(Request request) {
|
||||||
|
String clientRegistrationId = request.getClientRegistrationId();
|
||||||
|
Authentication authentication = request.getAuthentication();
|
||||||
|
ServerWebExchange exchange = request.getExchange();
|
||||||
|
return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange)
|
||||||
|
.switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<OAuth2AuthorizedClient> authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
|
||||||
|
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found")))
|
||||||
|
.flatMap(clientRegistration -> {
|
||||||
|
if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) {
|
||||||
|
return clientCredentials(clientRegistration, authentication, exchange);
|
||||||
|
}
|
||||||
|
return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<? extends OAuth2AuthorizedClient> clientCredentials(
|
||||||
|
ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) {
|
||||||
|
OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration);
|
||||||
|
return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest)
|
||||||
|
.flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse));
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<OAuth2AuthorizedClient> clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) {
|
||||||
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(
|
||||||
|
clientRegistration, authentication.getName(), tokenResponse.getAccessToken());
|
||||||
|
return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange)
|
||||||
|
.thenReturn(authorizedClient);
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attempts to load the client registration id from the current {@link Authentication}
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private Mono<String> clientRegistrationId(Mono<Authentication> authentication) {
|
||||||
|
return authentication
|
||||||
|
.filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken)
|
||||||
|
.cast(OAuth2AuthenticationToken.class)
|
||||||
|
.map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<Authentication> currentAuthentication() {
|
||||||
|
return ReactiveSecurityContextHolder.getContext()
|
||||||
|
.map(SecurityContext::getAuthentication)
|
||||||
|
.defaultIfEmpty(ANONYMOUS_USER_TOKEN);
|
||||||
|
}
|
||||||
|
|
||||||
|
private Mono<ServerWebExchange> currentServerWebExchange() {
|
||||||
|
return Mono.subscriberContext()
|
||||||
|
.filter(c -> c.hasKey(ServerWebExchange.class))
|
||||||
|
.map(c -> c.get(ServerWebExchange.class));
|
||||||
|
}
|
||||||
|
|
||||||
|
static class Request {
|
||||||
|
private final String clientRegistrationId;
|
||||||
|
private final Authentication authentication;
|
||||||
|
private final ServerWebExchange exchange;
|
||||||
|
|
||||||
|
public Request(String clientRegistrationId, Authentication authentication,
|
||||||
|
ServerWebExchange exchange) {
|
||||||
|
this.clientRegistrationId = clientRegistrationId;
|
||||||
|
this.authentication = authentication;
|
||||||
|
this.exchange = exchange;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getClientRegistrationId() {
|
||||||
|
return this.clientRegistrationId;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Authentication getAuthentication() {
|
||||||
|
return this.authentication;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ServerWebExchange getExchange() {
|
||||||
|
return this.exchange;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
package org.springframework.security.oauth2.client.web.reactive.function.client;
|
||||||
|
|
||||||
|
import org.junit.Before;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.junit.runner.RunWith;
|
import org.junit.runner.RunWith;
|
||||||
import org.mockito.Mock;
|
import org.mockito.Mock;
|
||||||
|
@ -88,7 +89,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
@Mock
|
@Mock
|
||||||
private ServerWebExchange serverWebExchange;
|
private ServerWebExchange serverWebExchange;
|
||||||
|
|
||||||
private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction();
|
private ServerOAuth2AuthorizedClientExchangeFilterFunction function;
|
||||||
|
|
||||||
private MockExchangeFunction exchange = new MockExchangeFunction();
|
private MockExchangeFunction exchange = new MockExchangeFunction();
|
||||||
|
|
||||||
|
@ -100,6 +101,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
Instant.now(),
|
Instant.now(),
|
||||||
Instant.now().plus(Duration.ofDays(1)));
|
Instant.now().plus(Duration.ofDays(1)));
|
||||||
|
|
||||||
|
@Before
|
||||||
|
public void setup() {
|
||||||
|
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
|
public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() {
|
||||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||||
|
@ -155,7 +161,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
this.accessToken.getTokenValue(),
|
this.accessToken.getTokenValue(),
|
||||||
issuedAt,
|
issuedAt,
|
||||||
accessTokenExpiresAt);
|
accessTokenExpiresAt);
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
|
@ -204,7 +209,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
this.accessToken.getTokenValue(),
|
this.accessToken.getTokenValue(),
|
||||||
issuedAt,
|
issuedAt,
|
||||||
accessTokenExpiresAt);
|
accessTokenExpiresAt);
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt);
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
|
@ -236,8 +240,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
|
public void filterWhenRefreshTokenNullThenShouldRefreshFalse() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
"principalName", this.accessToken);
|
"principalName", this.accessToken);
|
||||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||||
|
@ -258,8 +260,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenNotExpiredThenShouldRefreshFalse() {
|
public void filterWhenNotExpiredThenShouldRefreshFalse() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
"principalName", this.accessToken, refreshToken);
|
"principalName", this.accessToken, refreshToken);
|
||||||
|
@ -281,8 +281,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
|
public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
"principalName", this.accessToken, refreshToken);
|
"principalName", this.accessToken, refreshToken);
|
||||||
|
@ -306,7 +304,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
|
public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
this.function.setDefaultOAuth2AuthorizedClient(true);
|
this.function.setDefaultOAuth2AuthorizedClient(true);
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||||
|
@ -337,8 +334,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
|
public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
@ -359,8 +354,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
|
public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() {
|
||||||
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository);
|
|
||||||
|
|
||||||
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt());
|
||||||
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration,
|
||||||
"principalName", this.accessToken, refreshToken);
|
"principalName", this.accessToken, refreshToken);
|
||||||
|
|
|
@ -29,9 +29,10 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc
|
||||||
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
|
||||||
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
|
import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient;
|
||||||
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
|
||||||
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
|
||||||
|
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
|
||||||
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
|
||||||
import org.springframework.util.ReflectionUtils;
|
import org.springframework.util.ReflectionUtils;
|
||||||
import reactor.core.publisher.Hooks;
|
|
||||||
import reactor.core.publisher.Mono;
|
import reactor.core.publisher.Mono;
|
||||||
import reactor.util.context.Context;
|
import reactor.util.context.Context;
|
||||||
|
|
||||||
|
@ -50,6 +51,8 @@ import static org.mockito.Mockito.when;
|
||||||
*/
|
*/
|
||||||
@RunWith(MockitoJUnitRunner.class)
|
@RunWith(MockitoJUnitRunner.class)
|
||||||
public class OAuth2AuthorizedClientArgumentResolverTests {
|
public class OAuth2AuthorizedClientArgumentResolverTests {
|
||||||
|
@Mock
|
||||||
|
private ReactiveClientRegistrationRepository clientRegistrationRepository;
|
||||||
@Mock
|
@Mock
|
||||||
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
private ServerOAuth2AuthorizedClientRepository authorizedClientRepository;
|
||||||
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
|
private OAuth2AuthorizedClientArgumentResolver argumentResolver;
|
||||||
|
@ -59,15 +62,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
|
||||||
|
|
||||||
@Before
|
@Before
|
||||||
public void setUp() {
|
public void setUp() {
|
||||||
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository);
|
this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository);
|
||||||
this.authorizedClient = mock(OAuth2AuthorizedClient.class);
|
this.authorizedClient = mock(OAuth2AuthorizedClient.class);
|
||||||
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient));
|
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient));
|
||||||
Hooks.onOperatorDebug();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
|
public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() {
|
||||||
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null))
|
assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null))
|
||||||
.isInstanceOf(IllegalArgumentException.class);
|
.isInstanceOf(IllegalArgumentException.class);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -94,11 +96,13 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
|
||||||
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
|
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
|
||||||
assertThatThrownBy(() -> resolveArgument(methodParameter))
|
assertThatThrownBy(() -> resolveArgument(methodParameter))
|
||||||
.isInstanceOf(IllegalArgumentException.class)
|
.isInstanceOf(IllegalArgumentException.class)
|
||||||
.hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").");
|
.hasMessage("The clientRegistrationId could not be resolved. Please provide one");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
|
public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() {
|
||||||
|
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
|
||||||
|
TestClientRegistrations.clientRegistration().build()));
|
||||||
this.authentication = mock(OAuth2AuthenticationToken.class);
|
this.authentication = mock(OAuth2AuthenticationToken.class);
|
||||||
when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
|
when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1");
|
||||||
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
|
MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class);
|
||||||
|
@ -108,18 +112,24 @@ public class OAuth2AuthorizedClientArgumentResolverTests {
|
||||||
@Test
|
@Test
|
||||||
public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() {
|
public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() {
|
||||||
this.authentication = null;
|
this.authentication = null;
|
||||||
|
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
|
||||||
|
TestClientRegistrations.clientRegistration().build()));
|
||||||
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
||||||
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
|
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() {
|
public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() {
|
||||||
|
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
|
||||||
|
TestClientRegistrations.clientRegistration().build()));
|
||||||
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
||||||
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
|
assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
|
public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() {
|
||||||
|
when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just(
|
||||||
|
TestClientRegistrations.clientRegistration().build()));
|
||||||
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty());
|
when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty());
|
||||||
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class);
|
||||||
assertThatThrownBy(() -> resolveArgument(methodParameter))
|
assertThatThrownBy(() -> resolveArgument(methodParameter))
|
||||||
|
|
Loading…
Reference in New Issue