From 438d2911fbcd5bbfe13e8c1bfa66c78b94ea97e4 Mon Sep 17 00:00:00 2001 From: Rob Winch Date: Wed, 5 Sep 2018 20:38:12 -0500 Subject: [PATCH] OAuth2AuthorizedClientResolver Extract out a private API for shared code between the argument resolver and WebClient support. This makes it easier to make changes in both locations. Later we will extract this out so it is not a copy/paste effort. Issue: gh-4921 --- .../ReactiveOAuth2ClientImportSelector.java | 13 +- .../OAuth2AuthorizedClientResolver.java | 185 ++++++++++++++++ ...uthorizedClientExchangeFilterFunction.java | 205 ++++-------------- ...Auth2AuthorizedClientArgumentResolver.java | 54 +---- .../OAuth2AuthorizedClientResolver.java | 186 ++++++++++++++++ ...izedClientExchangeFilterFunctionTests.java | 21 +- ...AuthorizedClientArgumentResolverTests.java | 20 +- 7 files changed, 457 insertions(+), 227 deletions(-) create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java create mode 100644 oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java index 760e3aecf0..a089518a56 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/reactive/ReactiveOAuth2ClientImportSelector.java @@ -21,6 +21,7 @@ import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.ImportSelector; import org.springframework.core.type.AnnotationMetadata; import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientService; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.AuthenticatedPrincipalServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.client.web.reactive.result.method.annotation.OAuth2AuthorizedClientArgumentResolver; @@ -53,17 +54,25 @@ final class ReactiveOAuth2ClientImportSelector implements ImportSelector { @Configuration static class OAuth2ClientWebFluxSecurityConfiguration implements WebFluxConfigurer { + private ReactiveClientRegistrationRepository clientRegistrationRepository; + private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private ReactiveOAuth2AuthorizedClientService authorizedClientService; @Override public void configureArgumentResolvers(ArgumentResolverConfigurer configurer) { - if (this.authorizedClientRepository != null) { - configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(getAuthorizedClientRepository())); + if (this.authorizedClientRepository != null && this.clientRegistrationRepository != null) { + configurer.addCustomResolver(new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, getAuthorizedClientRepository())); } } + @Autowired(required = false) + public void setClientRegistrationRepository( + ReactiveClientRegistrationRepository clientRegistrationRepository) { + this.clientRegistrationRepository = clientRegistrationRepository; + } + @Autowired(required = false) public void setAuthorizedClientRepository(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { this.authorizedClientRepository = authorizedClientRepository; diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java new file mode 100644 index 0000000000..6b381c6d2b --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/OAuth2AuthorizedClientResolver.java @@ -0,0 +1,185 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.function.client; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Optional; + +/** + * @author Rob Winch + * @since 5.1 + */ +class OAuth2AuthorizedClientResolver { + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + + private boolean defaultOAuth2AuthorizedClient; + + public OAuth2AuthorizedClientResolver( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is + * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be + * resolved from the current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. + * Default is false. + */ + public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { + this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + } + + /** + * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + + Mono createDefaultedRequest(String clientRegistrationId, + Authentication authentication, ServerWebExchange exchange) { + Mono defaultedAuthentication = Mono.justOrEmpty(authentication) + .switchIfEmpty(currentAuthentication()); + + Mono defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId) + .switchIfEmpty(clientRegistrationId(defaultedAuthentication)); + + Mono> defaultedExchange = Mono.justOrEmpty(exchange) + .switchIfEmpty(currentServerWebExchange()).map(Optional::of) + .defaultIfEmpty(Optional.empty()); + + return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) + .map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null))); + } + + Mono loadAuthorizedClient(Request request) { + String clientRegistrationId = request.getClientRegistrationId(); + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange) + .switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange)); + } + + private Mono authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) + .flatMap(clientRegistration -> { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return clientCredentials(clientRegistration, authentication, exchange); + } + return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); + }); +} + + private Mono clientCredentials( + ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) + .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse)); + } + + private Mono clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, authentication.getName(), tokenResponse.getAccessToken()); + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange) + .thenReturn(authorizedClient); + } + + /** + * Attempts to load the client registration id from the current {@link Authentication} + * @return + */ + private Mono clientRegistrationId(Mono authentication) { + return authentication + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } + + static class Request { + private final String clientRegistrationId; + private final Authentication authentication; + private final ServerWebExchange exchange; + + public Request(String clientRegistrationId, Authentication authentication, + ServerWebExchange exchange) { + this.clientRegistrationId = clientRegistrationId; + this.authentication = authentication; + this.exchange = exchange; + } + + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public ServerWebExchange getExchange() { + return this.exchange; + } + } +} diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java index ce915f120a..3d66f295a8 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java @@ -16,28 +16,21 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; -import com.sun.security.ntlm.Server; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.core.Authentication; -import org.springframework.security.core.GrantedAuthority; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; -import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; import org.springframework.security.oauth2.client.registration.ClientRegistration; import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.security.oauth2.core.AuthorizationGrantType; import org.springframework.security.oauth2.core.OAuth2RefreshToken; -import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; import org.springframework.util.Assert; import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.reactive.function.client.ClientRequest; @@ -51,9 +44,7 @@ import java.net.URI; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.Collection; import java.util.Map; -import java.util.Optional; import java.util.function.Consumer; import static org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors.oauth2AccessTokenResponse; @@ -88,20 +79,13 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements private Duration accessTokenExpiresSkew = Duration.ofMinutes(1); - private boolean defaultOAuth2AuthorizedClient; - - private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = - new WebClientReactiveClientCredentialsTokenResponseClient(); - - private ReactiveClientRegistrationRepository clientRegistrationRepository; - private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; - public ServerOAuth2AuthorizedClientExchangeFilterFunction() {} + private final OAuth2AuthorizedClientResolver authorizedClientResolver; public ServerOAuth2AuthorizedClientExchangeFilterFunction(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { - this.clientRegistrationRepository = clientRegistrationRepository; this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); } /** @@ -142,6 +126,9 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, authorizedClient); } + private static OAuth2AuthorizedClient oauth2AuthorizedClient(ClientRequest request) { + return (OAuth2AuthorizedClient) request.attributes().get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME); + } /** * Modifies the {@link ClientRequest#attributes()} to include the {@link OAuth2AuthorizedClient} to be used for @@ -166,6 +153,10 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(SERVER_WEB_EXCHANGE_ATTR_NAME, serverWebExchange); } + private static ServerWebExchange serverWebExchange(ClientRequest request) { + return (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); + } + /** * Modifies the {@link ClientRequest#attributes()} to include the {@link ClientRegistration#getRegistrationId()} to * be used to look up the {@link OAuth2AuthorizedClient}. @@ -178,6 +169,14 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements return attributes -> attributes.put(CLIENT_REGISTRATION_ID_ATTR_NAME, clientRegistrationId); } + private static String clientRegistrationId(ClientRequest request) { + OAuth2AuthorizedClient authorizedClient = oauth2AuthorizedClient(request); + if (authorizedClient != null) { + return authorizedClient.getClientRegistration().getRegistrationId(); + } + return (String) request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME); + } + /** * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be @@ -186,7 +185,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements * Default is false. */ public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { - this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(defaultOAuth2AuthorizedClient); } /** @@ -196,8 +195,7 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements */ public void setClientCredentialsTokenResponseClient( ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { - Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); - this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + this.authorizedClientResolver.setClientCredentialsTokenResponseClient(clientCredentialsTokenResponseClient); } /** @@ -212,128 +210,59 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements @Override public Mono filter(ClientRequest request, ExchangeFunction next) { - return authorizedClient(request) - .flatMap(authorizedClient -> refreshIfNecessary(next, authorizedClient, request)) + return authorizedClient(request, next) .map(authorizedClient -> bearer(request, authorizedClient)) .flatMap(next::exchange) .switchIfEmpty(next.exchange(request)); } - private Mono serverWebExchange(ClientRequest request) { - ServerWebExchange exchange = (ServerWebExchange) request.attributes().get(SERVER_WEB_EXCHANGE_ATTR_NAME); - return Mono.justOrEmpty(exchange) - .switchIfEmpty(serverWebExchange()); + private Mono authorizedClient(ClientRequest request, ExchangeFunction next) { + OAuth2AuthorizedClient authorizedClientFromAttrs = oauth2AuthorizedClient(request); + return Mono.justOrEmpty(authorizedClientFromAttrs) + .switchIfEmpty(Mono.defer(() -> loadAuthorizedClient(request))) + .flatMap(authorizedClient -> refreshIfNecessary(request, next, authorizedClient)); } - private Mono serverWebExchange() { - return Mono.subscriberContext() - .filter(c -> c.hasKey(ServerWebExchange.class)) - .map(c -> c.get(ServerWebExchange.class)); + private Mono loadAuthorizedClient(ClientRequest request) { + return createRequest(request) + .flatMap(r -> this.authorizedClientResolver.loadAuthorizedClient(r)); } - private Mono authorizedClient(ClientRequest request) { - Optional attribute = request.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME) - .map(OAuth2AuthorizedClient.class::cast); - return Mono.justOrEmpty(attribute) - .switchIfEmpty(findAuthorizedClientByRegistrationId(request)); + private Mono createRequest(ClientRequest request) { + String clientRegistrationId = clientRegistrationId(request); + Authentication authentication = null; + ServerWebExchange exchange = serverWebExchange(request); + return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, authentication, exchange); } - private Mono findAuthorizedClientByRegistrationId(ClientRequest request) { - if (this.authorizedClientRepository == null) { - return Mono.empty(); - } - - return currentAuthentication() - .flatMap(principal -> clientRegistrationId(request, principal) - .flatMap(clientRegistrationId -> serverWebExchange(request).flatMap(exchange -> loadAuthorizedClient(clientRegistrationId, exchange, principal))) - ); - } - - private Mono clientRegistrationId(ClientRequest request, Authentication authentication) { - return Mono.justOrEmpty(request.attributes().get(CLIENT_REGISTRATION_ID_ATTR_NAME)) - .cast(String.class) - .switchIfEmpty(clientRegistrationId(authentication)); - } - - private Mono clientRegistrationId(Authentication authentication) { - return Mono.justOrEmpty(authentication) - .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) - .cast(OAuth2AuthenticationToken.class) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); - } - - private Mono loadAuthorizedClient(String clientRegistrationId, - ServerWebExchange exchange, Authentication principal) { - return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, exchange) - .switchIfEmpty(authorizedClientNotFound(clientRegistrationId, exchange)); - } - - private Mono authorizedClientNotFound(String clientRegistrationId, ServerWebExchange exchange) { - return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) - .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) - .flatMap(clientRegistration -> { - if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { - return clientCredentials(clientRegistration, exchange); - } - return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); - }); - } - - private Mono clientCredentials( - ClientRegistration clientRegistration, ServerWebExchange exchange) { - OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); - return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) - .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, tokenResponse, exchange)); - } - - private Mono clientCredentialsResponse(ClientRegistration clientRegistration, OAuth2AccessTokenResponse tokenResponse, ServerWebExchange exchange) { - return currentAuthentication() - .flatMap(principal -> { - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( - clientRegistration, (principal != null ? - principal.getName() : - "anonymousUser"), - tokenResponse.getAccessToken()); - - return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null) - .thenReturn(authorizedClient); - }); - } - - private Mono refreshIfNecessary(ExchangeFunction next, OAuth2AuthorizedClient authorizedClient, ClientRequest request) { + private Mono refreshIfNecessary(ClientRequest request, ExchangeFunction next, OAuth2AuthorizedClient authorizedClient) { if (shouldRefresh(authorizedClient)) { - return serverWebExchange(request) - .flatMap(exchange -> refreshAuthorizedClient(next, authorizedClient, exchange)); + return createRequest(request) + .flatMap(r -> refreshAuthorizedClient(next, authorizedClient, r)); } return Mono.just(authorizedClient); } private Mono refreshAuthorizedClient(ExchangeFunction next, - OAuth2AuthorizedClient authorizedClient, ServerWebExchange exchange) { + OAuth2AuthorizedClient authorizedClient, OAuth2AuthorizedClientResolver.Request r) { + ServerWebExchange exchange = r.getExchange(); + Authentication authentication = r.getAuthentication(); ClientRegistration clientRegistration = authorizedClient .getClientRegistration(); String tokenUri = clientRegistration .getProviderDetails().getTokenUri(); - ClientRequest request = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) + ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, URI.create(tokenUri)) .header(HttpHeaders.ACCEPT, MediaType.APPLICATION_JSON_VALUE) .headers(headers -> headers.setBasicAuth(clientRegistration.getClientId(), clientRegistration.getClientSecret())) .body(refreshTokenBody(authorizedClient.getRefreshToken().getTokenValue())) .build(); - return next.exchange(request) - .flatMap(response -> response.body(oauth2AccessTokenResponse())) + return next.exchange(refreshRequest) + .flatMap(refreshResponse -> refreshResponse.body(oauth2AccessTokenResponse())) .map(accessTokenResponse -> new OAuth2AuthorizedClient(authorizedClient.getClientRegistration(), authorizedClient.getPrincipalName(), accessTokenResponse.getAccessToken(), accessTokenResponse.getRefreshToken())) - .flatMap(result -> currentAuthentication() - .defaultIfEmpty(new PrincipalNameAuthentication(authorizedClient.getPrincipalName())) - .flatMap(principal -> this.authorizedClientRepository.saveAuthorizedClient(result, principal, exchange)) + .flatMap(result -> this.authorizedClientRepository.saveAuthorizedClient(result, authentication, exchange) .thenReturn(result)); } - private Mono currentAuthentication() { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(ANONYMOUS_USER_TOKEN); - } - private boolean shouldRefresh(OAuth2AuthorizedClient authorizedClient) { if (this.authorizedClientRepository == null) { return false; @@ -361,52 +290,4 @@ public final class ServerOAuth2AuthorizedClientExchangeFilterFunction implements .fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue()) .with("refresh_token", refreshToken); } - - private static class PrincipalNameAuthentication implements Authentication { - private final String username; - - private PrincipalNameAuthentication(String username) { - this.username = username; - } - - @Override - public Collection getAuthorities() { - throw unsupported(); - } - - @Override - public Object getCredentials() { - throw unsupported(); - } - - @Override - public Object getDetails() { - throw unsupported(); - } - - @Override - public Object getPrincipal() { - throw unsupported(); - } - - @Override - public boolean isAuthenticated() { - throw unsupported(); - } - - @Override - public void setAuthenticated(boolean isAuthenticated) - throws IllegalArgumentException { - throw unsupported(); - } - - @Override - public String getName() { - return this.username; - } - - private UnsupportedOperationException unsupported() { - return new UnsupportedOperationException("Not Supported"); - } - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java index 0012353b82..35c1e2cdd0 100644 --- a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolver.java @@ -18,15 +18,9 @@ package org.springframework.security.oauth2.client.web.reactive.result.method.an import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotatedElementUtils; -import org.springframework.security.authentication.AnonymousAuthenticationToken; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.authority.AuthorityUtils; -import org.springframework.security.core.context.ReactiveSecurityContextHolder; -import org.springframework.security.core.context.SecurityContext; -import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; -import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -56,16 +50,18 @@ import reactor.core.publisher.Mono; * @see RegisteredOAuth2AuthorizedClient */ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMethodArgumentResolver { - private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private final OAuth2AuthorizedClientResolver authorizedClientResolver; /** * Constructs an {@code OAuth2AuthorizedClientArgumentResolver} using the provided parameters. * * @param authorizedClientRepository the authorized client repository */ - public OAuth2AuthorizedClientArgumentResolver(ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + public OAuth2AuthorizedClientArgumentResolver(ReactiveClientRegistrationRepository clientRegistrationRepository, ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); - this.authorizedClientRepository = authorizedClientRepository; + this.authorizedClientResolver = new OAuth2AuthorizedClientResolver(clientRegistrationRepository, authorizedClientRepository); + this.authorizedClientResolver.setDefaultOAuth2AuthorizedClient(true); } @Override @@ -80,41 +76,11 @@ public final class OAuth2AuthorizedClientArgumentResolver implements HandlerMeth RegisteredOAuth2AuthorizedClient authorizedClientAnnotation = AnnotatedElementUtils .findMergedAnnotation(parameter.getParameter(), RegisteredOAuth2AuthorizedClient.class); - Mono clientRegistrationId = Mono.justOrEmpty(authorizedClientAnnotation.registrationId()) - .filter(id -> !StringUtils.isEmpty(id)) - .switchIfEmpty(clientRegistrationId()) - .switchIfEmpty(Mono.defer(() -> Mono.error(new IllegalArgumentException( - "Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\").")))); + String clientRegistrationId = StringUtils.hasLength(authorizedClientAnnotation.registrationId()) ? + authorizedClientAnnotation.registrationId() : null; - Mono principal = ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .defaultIfEmpty(new AnonymousAuthenticationToken("key", "anonymous", - AuthorityUtils.createAuthorityList("ROLE_ANONYMOUS"))); - - Mono authorizedClient = Mono - .zip(clientRegistrationId, principal).switchIfEmpty( - clientRegistrationId.flatMap(id -> Mono.error(new IllegalStateException( - "Unable to resolve the Authorized Client with registration identifier \"" - + id - + "\". An \"authenticated\" or \"unauthenticated\" session is required. To allow for unauthenticated access, ensure ServerHttpSecurity.anonymous() is configured.")))) - .flatMap(zipped -> { - String registrationId = zipped.getT1(); - Authentication authentication = zipped.getT2(); - return this.authorizedClientRepository - .loadAuthorizedClient(registrationId, authentication, exchange).switchIfEmpty(Mono.defer(() -> Mono - .error(new ClientAuthorizationRequiredException( - registrationId)))); - }).cast(OAuth2AuthorizedClient.class); - - return authorizedClient.cast(Object.class); + return this.authorizedClientResolver.createDefaultedRequest(clientRegistrationId, null, exchange) + .flatMap(this.authorizedClientResolver::loadAuthorizedClient); }); } - - private Mono clientRegistrationId() { - return ReactiveSecurityContextHolder.getContext() - .map(SecurityContext::getAuthentication) - .filter(authentication -> authentication instanceof OAuth2AuthenticationToken) - .cast(OAuth2AuthenticationToken.class) - .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); - } } diff --git a/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java new file mode 100644 index 0000000000..a90f65e9c5 --- /dev/null +++ b/oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientResolver.java @@ -0,0 +1,186 @@ +/* + * Copyright 2002-2018 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.oauth2.client.web.reactive.result.method.annotation; + +import org.springframework.security.authentication.AnonymousAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.ReactiveSecurityContextHolder; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException; +import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; +import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.endpoint.OAuth2ClientCredentialsGrantRequest; +import org.springframework.security.oauth2.client.endpoint.ReactiveOAuth2AccessTokenResponseClient; +import org.springframework.security.oauth2.client.endpoint.WebClientReactiveClientCredentialsTokenResponseClient; +import org.springframework.security.oauth2.client.registration.ClientRegistration; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; +import org.springframework.security.oauth2.core.AuthorizationGrantType; +import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +import java.util.Optional; + +/** + * @author Rob Winch + * @since 5.1 + */ +class OAuth2AuthorizedClientResolver { + + private static final AnonymousAuthenticationToken ANONYMOUS_USER_TOKEN = new AnonymousAuthenticationToken("anonymous", "anonymousUser", + AuthorityUtils.createAuthorityList("ROLE_USER")); + + private final ReactiveClientRegistrationRepository clientRegistrationRepository; + + private final ServerOAuth2AuthorizedClientRepository authorizedClientRepository; + + private ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient = new WebClientReactiveClientCredentialsTokenResponseClient(); + + private boolean defaultOAuth2AuthorizedClient; + + public OAuth2AuthorizedClientResolver( + ReactiveClientRegistrationRepository clientRegistrationRepository, + ServerOAuth2AuthorizedClientRepository authorizedClientRepository) { + Assert.notNull(clientRegistrationRepository, "clientRegistrationRepository cannot be null"); + Assert.notNull(authorizedClientRepository, "authorizedClientRepository cannot be null"); + this.clientRegistrationRepository = clientRegistrationRepository; + this.authorizedClientRepository = authorizedClientRepository; + } + + /** + * If true, a default {@link OAuth2AuthorizedClient} can be discovered from the current Authentication. It is + * recommended to be cautious with this feature since all HTTP requests will receive the access token if it can be + * resolved from the current Authentication. + * @param defaultOAuth2AuthorizedClient true if a default {@link OAuth2AuthorizedClient} should be used, else false. + * Default is false. + */ + public void setDefaultOAuth2AuthorizedClient(boolean defaultOAuth2AuthorizedClient) { + this.defaultOAuth2AuthorizedClient = defaultOAuth2AuthorizedClient; + } + + /** + * Sets the {@link ReactiveOAuth2AccessTokenResponseClient} to be used for getting an {@link OAuth2AuthorizedClient} for + * client_credentials grant. + * @param clientCredentialsTokenResponseClient the client to use + */ + public void setClientCredentialsTokenResponseClient( + ReactiveOAuth2AccessTokenResponseClient clientCredentialsTokenResponseClient) { + Assert.notNull(clientCredentialsTokenResponseClient, "clientCredentialsTokenResponseClient cannot be null"); + this.clientCredentialsTokenResponseClient = clientCredentialsTokenResponseClient; + } + + Mono createDefaultedRequest(String clientRegistrationId, + Authentication authentication, ServerWebExchange exchange) { + Mono defaultedAuthentication = Mono.justOrEmpty(authentication) + .switchIfEmpty(currentAuthentication()); + + Mono defaultedRegistrationId = Mono.justOrEmpty(clientRegistrationId) + .switchIfEmpty(clientRegistrationId(defaultedAuthentication)) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("The clientRegistrationId could not be resolved. Please provide one"))); + + Mono> defaultedExchange = Mono.justOrEmpty(exchange) + .switchIfEmpty(currentServerWebExchange()).map(Optional::of) + .defaultIfEmpty(Optional.empty()); + + return Mono.zip(defaultedRegistrationId, defaultedAuthentication, defaultedExchange) + .map(t3 -> new Request(t3.getT1(), t3.getT2(), t3.getT3().orElse(null))); + } + + Mono loadAuthorizedClient(Request request) { + String clientRegistrationId = request.getClientRegistrationId(); + Authentication authentication = request.getAuthentication(); + ServerWebExchange exchange = request.getExchange(); + return this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, authentication, exchange) + .switchIfEmpty(authorizedClientNotLoaded(clientRegistrationId, authentication, exchange)); + } + + private Mono authorizedClientNotLoaded(String clientRegistrationId, Authentication authentication, ServerWebExchange exchange) { + return this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId) + .switchIfEmpty(Mono.error(() -> new IllegalArgumentException("Client Registration with id " + clientRegistrationId + " was not found"))) + .flatMap(clientRegistration -> { + if (AuthorizationGrantType.CLIENT_CREDENTIALS.equals(clientRegistration.getAuthorizationGrantType())) { + return clientCredentials(clientRegistration, authentication, exchange); + } + return Mono.error(() -> new ClientAuthorizationRequiredException(clientRegistrationId)); + }); +} + + private Mono clientCredentials( + ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange) { + OAuth2ClientCredentialsGrantRequest grantRequest = new OAuth2ClientCredentialsGrantRequest(clientRegistration); + return this.clientCredentialsTokenResponseClient.getTokenResponse(grantRequest) + .flatMap(tokenResponse -> clientCredentialsResponse(clientRegistration, authentication, exchange, tokenResponse)); + } + + private Mono clientCredentialsResponse(ClientRegistration clientRegistration, Authentication authentication, ServerWebExchange exchange, OAuth2AccessTokenResponse tokenResponse) { + OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient( + clientRegistration, authentication.getName(), tokenResponse.getAccessToken()); + return this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, authentication, exchange) + .thenReturn(authorizedClient); + } + + /** + * Attempts to load the client registration id from the current {@link Authentication} + * @return + */ + private Mono clientRegistrationId(Mono authentication) { + return authentication + .filter(t -> this.defaultOAuth2AuthorizedClient && t instanceof OAuth2AuthenticationToken) + .cast(OAuth2AuthenticationToken.class) + .map(OAuth2AuthenticationToken::getAuthorizedClientRegistrationId); + } + + private Mono currentAuthentication() { + return ReactiveSecurityContextHolder.getContext() + .map(SecurityContext::getAuthentication) + .defaultIfEmpty(ANONYMOUS_USER_TOKEN); + } + + private Mono currentServerWebExchange() { + return Mono.subscriberContext() + .filter(c -> c.hasKey(ServerWebExchange.class)) + .map(c -> c.get(ServerWebExchange.class)); + } + + static class Request { + private final String clientRegistrationId; + private final Authentication authentication; + private final ServerWebExchange exchange; + + public Request(String clientRegistrationId, Authentication authentication, + ServerWebExchange exchange) { + this.clientRegistrationId = clientRegistrationId; + this.authentication = authentication; + this.exchange = exchange; + } + + public String getClientRegistrationId() { + return this.clientRegistrationId; + } + + public Authentication getAuthentication() { + return this.authentication; + } + + public ServerWebExchange getExchange() { + return this.exchange; + } + } +} diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java index 74ae319a8b..9faab5ba88 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java @@ -16,6 +16,7 @@ package org.springframework.security.oauth2.client.web.reactive.function.client; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.Mock; @@ -88,7 +89,7 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Mock private ServerWebExchange serverWebExchange; - private ServerOAuth2AuthorizedClientExchangeFilterFunction function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(); + private ServerOAuth2AuthorizedClientExchangeFilterFunction function; private MockExchangeFunction exchange = new MockExchangeFunction(); @@ -100,6 +101,11 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { Instant.now(), Instant.now().plus(Duration.ofDays(1))); + @Before + public void setup() { + this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); + } + @Test public void filterWhenAuthorizedClientNullThenAuthorizationHeaderNull() { ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -155,7 +161,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -204,7 +209,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { this.accessToken.getTokenValue(), issuedAt, accessTokenExpiresAt); - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", issuedAt, refreshTokenExpiresAt); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, @@ -236,8 +240,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenRefreshTokenNullThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken); ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) @@ -258,8 +260,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenNotExpiredThenShouldRefreshFalse() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); @@ -281,8 +281,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdThenAuthorizedClientResolved() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); @@ -306,7 +304,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdFromAuthenticationThenAuthorizedClientResolved() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); this.function.setDefaultOAuth2AuthorizedClient(true); OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); @@ -337,8 +334,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenDefaultOAuth2AuthorizedClientFalseThenEmpty() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com")) .build(); @@ -359,8 +354,6 @@ public class ServerOAuth2AuthorizedClientExchangeFilterFunctionTests { @Test public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenServerWebExchangeFromContext() { - this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(this.clientRegistrationRepository, this.authorizedClientRepository); - OAuth2RefreshToken refreshToken = new OAuth2RefreshToken("refresh-token", this.accessToken.getIssuedAt(), this.accessToken.getExpiresAt()); OAuth2AuthorizedClient authorizedClient = new OAuth2AuthorizedClient(this.registration, "principalName", this.accessToken, refreshToken); diff --git a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java index 762ef63a57..54b3d45fba 100644 --- a/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java +++ b/oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/result/method/annotation/OAuth2AuthorizedClientArgumentResolverTests.java @@ -29,9 +29,10 @@ import org.springframework.security.oauth2.client.ClientAuthorizationRequiredExc import org.springframework.security.oauth2.client.OAuth2AuthorizedClient; import org.springframework.security.oauth2.client.annotation.RegisteredOAuth2AuthorizedClient; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; +import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; +import org.springframework.security.oauth2.client.registration.TestClientRegistrations; import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository; import org.springframework.util.ReflectionUtils; -import reactor.core.publisher.Hooks; import reactor.core.publisher.Mono; import reactor.util.context.Context; @@ -50,6 +51,8 @@ import static org.mockito.Mockito.when; */ @RunWith(MockitoJUnitRunner.class) public class OAuth2AuthorizedClientArgumentResolverTests { + @Mock + private ReactiveClientRegistrationRepository clientRegistrationRepository; @Mock private ServerOAuth2AuthorizedClientRepository authorizedClientRepository; private OAuth2AuthorizedClientArgumentResolver argumentResolver; @@ -59,15 +62,14 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Before public void setUp() { - this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.authorizedClientRepository); + this.argumentResolver = new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, this.authorizedClientRepository); this.authorizedClient = mock(OAuth2AuthorizedClient.class); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.just(this.authorizedClient)); - Hooks.onOperatorDebug(); } @Test public void constructorWhenOAuth2AuthorizedClientServiceIsNullThenThrowIllegalArgumentException() { - assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(null)) + assertThatThrownBy(() -> new OAuth2AuthorizedClientArgumentResolver(this.clientRegistrationRepository, null)) .isInstanceOf(IllegalArgumentException.class); } @@ -94,11 +96,13 @@ public class OAuth2AuthorizedClientArgumentResolverTests { MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> resolveArgument(methodParameter)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("Unable to resolve the Client Registration Identifier. It must be provided via @RegisteredOAuth2AuthorizedClient(\"client1\") or @RegisteredOAuth2AuthorizedClient(registrationId = \"client1\")."); + .hasMessage("The clientRegistrationId could not be resolved. Please provide one"); } @Test public void resolveArgumentWhenRegistrationIdEmptyAndOAuth2AuthenticationThenResolves() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); this.authentication = mock(OAuth2AuthenticationToken.class); when(((OAuth2AuthenticationToken) this.authentication).getAuthorizedClientRegistrationId()).thenReturn("client1"); MethodParameter methodParameter = this.getMethodParameter("registrationIdEmpty", OAuth2AuthorizedClient.class); @@ -108,18 +112,24 @@ public class OAuth2AuthorizedClientArgumentResolverTests { @Test public void resolveArgumentWhenParameterTypeOAuth2AuthorizedClientAndCurrentAuthenticationNullThenResolves() { this.authentication = null; + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientFoundThenResolves() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThat(resolveArgument(methodParameter)).isSameAs(this.authorizedClient); } @Test public void resolveArgumentWhenOAuth2AuthorizedClientNotFoundThenThrowClientAuthorizationRequiredException() { + when(this.clientRegistrationRepository.findByRegistrationId(any())).thenReturn(Mono.just( + TestClientRegistrations.clientRegistration().build())); when(this.authorizedClientRepository.loadAuthorizedClient(anyString(), any(), any())).thenReturn(Mono.empty()); MethodParameter methodParameter = this.getMethodParameter("paramTypeAuthorizedClient", OAuth2AuthorizedClient.class); assertThatThrownBy(() -> resolveArgument(methodParameter))